diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index bec89ab888..80e0c8b216 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -60,13 +60,13 @@ jobs: uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.06 with: build_type: branch - node_type: "gpu-latest-1" + node_type: "gpu-v100-latest-1" arch: "amd64" container_image: "rapidsai/ci:latest" run_script: "ci/build_docs.sh" wheel-build-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-120-pip with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -78,7 +78,7 @@ jobs: wheel-publish-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@cuda-120-pip with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -88,7 +88,7 @@ jobs: wheel-build-raft-dask: needs: wheel-publish-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-120-pip with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -100,7 +100,7 @@ jobs: wheel-publish-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@cuda-120-pip with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 8175b4fbc7..fcb155d651 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -60,14 +60,14 @@ jobs: uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.06 with: build_type: pull-request - node_type: "gpu-latest-1" + node_type: "gpu-v100-latest-1" arch: "amd64" container_image: "rapidsai/ci:latest" run_script: "ci/build_docs.sh" wheel-build-pylibraft: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-120-pip with: build_type: pull-request package-name: pylibraft @@ -76,34 +76,31 @@ jobs: wheel-tests-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-120-pip with: build_type: pull-request package-name: pylibraft - test-before-amd64: "pip install cupy-cuda11x" - # On arm also need to install cupy from the specific webpage. 
- test-before-arm64: "pip install 'cupy-cuda11x<12.0.0' -f https://pip.cupy.dev/aarch64" - test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test" + test-unittest: "python -m pytest ./python/pylibraft/pylibraft/test" test-smoketest: "python ./ci/wheel_smoke_test_pylibraft.py" wheel-build-raft-dask: needs: wheel-tests-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-120-pip with: build_type: pull-request package-name: raft_dask package-dir: python/raft-dask - before-wheel: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-wheelhouse" + before-wheel: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft && python -m pip install --no-deps ./local-pylibraft/pylibraft*.whl" skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" wheel-tests-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-120-pip with: build_type: pull-request package-name: raft_dask # Always want to test against latest dask/distributed. - test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06" - test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06" - test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test" + test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06" + test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06" + test-unittest: "python -m pytest ./python/raft-dask/raft_dask/test" test-smoketest: "python ./ci/wheel_smoke_test_raft_dask.py" diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index dc8f7b6f2b..d389c4e2a9 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -32,19 +32,17 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-120-pip with: build_type: nightly branch: ${{ 
inputs.branch }} date: ${{ inputs.date }} sha: ${{ inputs.sha }} package-name: pylibraft - test-before-amd64: "pip install cupy-cuda11x" - test-before-arm64: "pip install 'cupy-cuda11x<12.0.0' -f https://pip.cupy.dev/aarch64" - test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test" + test-unittest: "python -m pytest ./python/pylibraft/pylibraft/test" wheel-tests-raft-dask: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.06 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-120-pip with: build_type: nightly branch: ${{ inputs.branch }} @@ -53,4 +51,4 @@ jobs: package-name: raft_dask test-before-amd64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06" test-before-arm64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06" - test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test" + test-unittest: "python -m pytest ./python/raft-dask/raft_dask/test" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d6e4ecb676..2a70632497 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,7 @@ repos: additional_dependencies: [toml] args: ["--config=pyproject.toml"] - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v11.1.0 + rev: v16.0.1 hooks: - id: clang-format types_or: [c, c++, cuda] diff --git a/CHANGELOG.md b/CHANGELOG.md index c4701f587f..c5ca5995e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,119 @@ +# raft 23.04.00 (6 Apr 2023) + +## 🚨 Breaking Changes + +- Pin `dask` and `distributed` for release ([#1399](https://github.com/rapidsai/raft/pull/1399)) [@galipremsagar](https://github.com/galipremsagar) +- Remove faiss_mr.hpp ([#1351](https://github.com/rapidsai/raft/pull/1351)) [@benfred](https://github.com/benfred) +- Removing FAISS from build ([#1340](https://github.com/rapidsai/raft/pull/1340)) [@cjnolet](https://github.com/cjnolet) +- Generic linalg::map ([#1337](https://github.com/rapidsai/raft/pull/1337)) [@achirkin](https://github.com/achirkin) +- Consolidate pre-compiled specializations into single `libraft` binary ([#1333](https://github.com/rapidsai/raft/pull/1333)) [@cjnolet](https://github.com/cjnolet) +- Generic linalg::map ([#1329](https://github.com/rapidsai/raft/pull/1329)) [@achirkin](https://github.com/achirkin) +- Update and standardize IVF indexes API ([#1328](https://github.com/rapidsai/raft/pull/1328)) [@viclafargue](https://github.com/viclafargue) +- IVF-Flat index splitting ([#1271](https://github.com/rapidsai/raft/pull/1271)) [@lowener](https://github.com/lowener) +- IVF-PQ: store cluster data in individual lists and reduce templates ([#1249](https://github.com/rapidsai/raft/pull/1249)) [@achirkin](https://github.com/achirkin) +- Fix for svd API ([#1190](https://github.com/rapidsai/raft/pull/1190)) [@lowener](https://github.com/lowener) +- Remove deprecated headers ([#1145](https://github.com/rapidsai/raft/pull/1145)) [@lowener](https://github.com/lowener) + +## 🐛 Bug Fixes + +- Fix primitives benchmarks ([#1389](https://github.com/rapidsai/raft/pull/1389)) [@ahendriksen](https://github.com/ahendriksen) +- Fixing index-url link on pip install docs ([#1378](https://github.com/rapidsai/raft/pull/1378)) 
[@cjnolet](https://github.com/cjnolet) +- Adding some functions back in that seem to be a copy/paste error ([#1373](https://github.com/rapidsai/raft/pull/1373)) [@cjnolet](https://github.com/cjnolet) +- Remove usage of Dask's `get_worker` ([#1365](https://github.com/rapidsai/raft/pull/1365)) [@pentschev](https://github.com/pentschev) +- Remove MANIFEST.in, use auto-generated one for sdists and package_data for wheels ([#1348](https://github.com/rapidsai/raft/pull/1348)) [@vyasr](https://github.com/vyasr) +- Revert "Generic linalg::map (#1329)" ([#1336](https://github.com/rapidsai/raft/pull/1336)) [@cjnolet](https://github.com/cjnolet) +- Small follow-up to specializations cleanup ([#1332](https://github.com/rapidsai/raft/pull/1332)) [@cjnolet](https://github.com/cjnolet) +- Fixing select_k specializations ([#1330](https://github.com/rapidsai/raft/pull/1330)) [@cjnolet](https://github.com/cjnolet) +- Fixing remaining bug in ann_quantized ([#1327](https://github.com/rapidsai/raft/pull/1327)) [@cjnolet](https://github.com/cjnolet) +- Fixing a couple small kmeans bugs ([#1274](https://github.com/rapidsai/raft/pull/1274)) [@cjnolet](https://github.com/cjnolet) +- Remove no longer instantiated templates from list of extern template declarations ([#1272](https://github.com/rapidsai/raft/pull/1272)) [@vyasr](https://github.com/vyasr) +- Bump pinned deps to 23.4 ([#1266](https://github.com/rapidsai/raft/pull/1266)) [@vyasr](https://github.com/vyasr) +- Fix the destruction of interruptible token registry ([#1229](https://github.com/rapidsai/raft/pull/1229)) [@achirkin](https://github.com/achirkin) +- Expose raft::handle_t in the public header ([#1192](https://github.com/rapidsai/raft/pull/1192)) [@vyasr](https://github.com/vyasr) +- Fix for svd API ([#1190](https://github.com/rapidsai/raft/pull/1190)) [@lowener](https://github.com/lowener) + +## 📖 Documentation + +- Adding architecture diagram to README.md ([#1370](https://github.com/rapidsai/raft/pull/1370)) [@cjnolet](https://github.com/cjnolet) +- Adding small readme image ([#1354](https://github.com/rapidsai/raft/pull/1354)) [@cjnolet](https://github.com/cjnolet) +- Fix serialize documentation of ivf_flat ([#1347](https://github.com/rapidsai/raft/pull/1347)) [@lowener](https://github.com/lowener) +- Small updates to docs ([#1339](https://github.com/rapidsai/raft/pull/1339)) [@cjnolet](https://github.com/cjnolet) + +## 🚀 New Features + +- Add Options to Generate Build Metrics Report ([#1369](https://github.com/rapidsai/raft/pull/1369)) [@divyegala](https://github.com/divyegala) +- Generic linalg::map ([#1337](https://github.com/rapidsai/raft/pull/1337)) [@achirkin](https://github.com/achirkin) +- Generic linalg::map ([#1329](https://github.com/rapidsai/raft/pull/1329)) [@achirkin](https://github.com/achirkin) +- matrix::select_k specializations ([#1268](https://github.com/rapidsai/raft/pull/1268)) [@achirkin](https://github.com/achirkin) +- Use rapids-cmake new COMPONENT exporting feature ([#1154](https://github.com/rapidsai/raft/pull/1154)) [@robertmaynard](https://github.com/robertmaynard) + +## 🛠️ Improvements + +- Pin `dask` and `distributed` for release ([#1399](https://github.com/rapidsai/raft/pull/1399)) [@galipremsagar](https://github.com/galipremsagar) +- Pin cupy in wheel tests to supported versions ([#1383](https://github.com/rapidsai/raft/pull/1383)) [@vyasr](https://github.com/vyasr) +- CAGRA ([#1375](https://github.com/rapidsai/raft/pull/1375)) [@tfeher](https://github.com/tfeher) +- add a distance epilogue function to the
bfknn call ([#1371](https://github.com/rapidsai/raft/pull/1371)) [@benfred](https://github.com/benfred) +- Relax UCX pin to allow 1.14 ([#1366](https://github.com/rapidsai/raft/pull/1366)) [@pentschev](https://github.com/pentschev) +- Generate pyproject dependencies with dfg ([#1364](https://github.com/rapidsai/raft/pull/1364)) [@vyasr](https://github.com/vyasr) +- Add nccl to dependencies.yaml ([#1361](https://github.com/rapidsai/raft/pull/1361)) [@benfred](https://github.com/benfred) +- Add extern template for ivfflat_interleaved_scan ([#1360](https://github.com/rapidsai/raft/pull/1360)) [@ahendriksen](https://github.com/ahendriksen) +- Stop setting package version attribute in wheels ([#1359](https://github.com/rapidsai/raft/pull/1359)) [@vyasr](https://github.com/vyasr) +- Fix ivf flat specialization header IdxT from uint64_t -> int64_t ([#1358](https://github.com/rapidsai/raft/pull/1358)) [@ahendriksen](https://github.com/ahendriksen) +- Remove faiss_mr.hpp ([#1351](https://github.com/rapidsai/raft/pull/1351)) [@benfred](https://github.com/benfred) +- Rename optional helper function ([#1345](https://github.com/rapidsai/raft/pull/1345)) [@viclafargue](https://github.com/viclafargue) +- Pass minimum target compile options through `raft::raft` ([#1341](https://github.com/rapidsai/raft/pull/1341)) [@cjnolet](https://github.com/cjnolet) +- Removing FAISS from build ([#1340](https://github.com/rapidsai/raft/pull/1340)) [@cjnolet](https://github.com/cjnolet) +- Add dispatch based on compute architecture ([#1335](https://github.com/rapidsai/raft/pull/1335)) [@ahendriksen](https://github.com/ahendriksen) +- Consolidate pre-compiled specializations into single `libraft` binary ([#1333](https://github.com/rapidsai/raft/pull/1333)) [@cjnolet](https://github.com/cjnolet) +- Update and standardize IVF indexes API ([#1328](https://github.com/rapidsai/raft/pull/1328)) [@viclafargue](https://github.com/viclafargue) +- Using int64_t specializations for `ivf_pq` and `refine` ([#1325](https://github.com/rapidsai/raft/pull/1325)) [@cjnolet](https://github.com/cjnolet) +- Migrate as much as possible to pyproject.toml ([#1324](https://github.com/rapidsai/raft/pull/1324)) [@vyasr](https://github.com/vyasr) +- Pass `AWS_SESSION_TOKEN` and `SCCACHE_S3_USE_SSL` vars to conda build ([#1321](https://github.com/rapidsai/raft/pull/1321)) [@ajschmidt8](https://github.com/ajschmidt8) +- Numerical stability fixes for l2 pairwise distance ([#1319](https://github.com/rapidsai/raft/pull/1319)) [@benfred](https://github.com/benfred) +- Consolidate linter configuration into pyproject.toml ([#1317](https://github.com/rapidsai/raft/pull/1317)) [@vyasr](https://github.com/vyasr) +- IVF-Flat Python wrappers ([#1316](https://github.com/rapidsai/raft/pull/1316)) [@tfeher](https://github.com/tfeher) +- Add stream overloads to `ivf_pq` serialize/deserialize methods ([#1315](https://github.com/rapidsai/raft/pull/1315)) [@divyegala](https://github.com/divyegala) +- Temporary buffer to view host or device memory in device ([#1313](https://github.com/rapidsai/raft/pull/1313)) [@divyegala](https://github.com/divyegala) +- RAFT skeleton project template ([#1312](https://github.com/rapidsai/raft/pull/1312)) [@cjnolet](https://github.com/cjnolet) +- Fix docs build to be `pydata-sphinx-theme=0.13.0` compatible ([#1311](https://github.com/rapidsai/raft/pull/1311)) [@galipremsagar](https://github.com/galipremsagar) +- Update to GCC 11 ([#1309](https://github.com/rapidsai/raft/pull/1309)) [@bdice](https://github.com/bdice) +- Reduce compile 
times of distance specializations ([#1307](https://github.com/rapidsai/raft/pull/1307)) [@ahendriksen](https://github.com/ahendriksen) +- Fix docs upload path ([#1305](https://github.com/rapidsai/raft/pull/1305)) [@AyodeAwe](https://github.com/AyodeAwe) +- Add end-to-end CUDA ann-benchmarks to raft ([#1304](https://github.com/rapidsai/raft/pull/1304)) [@cjnolet](https://github.com/cjnolet) +- Make docs builds less verbose ([#1302](https://github.com/rapidsai/raft/pull/1302)) [@AyodeAwe](https://github.com/AyodeAwe) +- Stop using versioneer to manage versions ([#1301](https://github.com/rapidsai/raft/pull/1301)) [@vyasr](https://github.com/vyasr) +- Adding util to get the device id for a pointer address ([#1297](https://github.com/rapidsai/raft/pull/1297)) [@cjnolet](https://github.com/cjnolet) +- Enable dfg in pre-commit. ([#1293](https://github.com/rapidsai/raft/pull/1293)) [@vyasr](https://github.com/vyasr) +- Python API for brute-force KNN ([#1292](https://github.com/rapidsai/raft/pull/1292)) [@cjnolet](https://github.com/cjnolet) +- support k up to 2048 in faiss select ([#1287](https://github.com/rapidsai/raft/pull/1287)) [@benfred](https://github.com/benfred) +- CI: Remove specification of manual stage for check_style.sh script. ([#1283](https://github.com/rapidsai/raft/pull/1283)) [@csadorf](https://github.com/csadorf) +- New Sparse Matrix APIs ([#1279](https://github.com/rapidsai/raft/pull/1279)) [@cjnolet](https://github.com/cjnolet) +- fix build on cuda 11.5 ([#1277](https://github.com/rapidsai/raft/pull/1277)) [@benfred](https://github.com/benfred) +- IVF-Flat index splitting ([#1271](https://github.com/rapidsai/raft/pull/1271)) [@lowener](https://github.com/lowener) +- Remove duplicate `librmm` runtime dependency ([#1264](https://github.com/rapidsai/raft/pull/1264)) [@ajschmidt8](https://github.com/ajschmidt8) +- build.sh: Add option to log nvcc compile times ([#1262](https://github.com/rapidsai/raft/pull/1262)) [@ahendriksen](https://github.com/ahendriksen) +- Reduce error handling verbosity in CI tests scripts ([#1259](https://github.com/rapidsai/raft/pull/1259)) [@AjayThorve](https://github.com/AjayThorve) +- Update shared workflow branches ([#1256](https://github.com/rapidsai/raft/pull/1256)) [@ajschmidt8](https://github.com/ajschmidt8) +- Keeping only compute similarity specializations for uint64_t for now ([#1255](https://github.com/rapidsai/raft/pull/1255)) [@cjnolet](https://github.com/cjnolet) +- Fix compile time explosion for minkowski distance ([#1254](https://github.com/rapidsai/raft/pull/1254)) [@ahendriksen](https://github.com/ahendriksen) +- Unpin `dask` and `distributed` for development ([#1253](https://github.com/rapidsai/raft/pull/1253)) [@galipremsagar](https://github.com/galipremsagar) +- Remove gpuCI scripts. 
([#1252](https://github.com/rapidsai/raft/pull/1252)) [@bdice](https://github.com/bdice) +- IVF-PQ: store cluster data in individual lists and reduce templates ([#1249](https://github.com/rapidsai/raft/pull/1249)) [@achirkin](https://github.com/achirkin) +- Fix inconsistency between the building doc and CMakeLists.txt ([#1248](https://github.com/rapidsai/raft/pull/1248)) [@yong-wang](https://github.com/yong-wang) +- Consolidating ANN benchmarks and tests ([#1243](https://github.com/rapidsai/raft/pull/1243)) [@cjnolet](https://github.com/cjnolet) +- mdspan view for IVF-PQ API ([#1236](https://github.com/rapidsai/raft/pull/1236)) [@viclafargue](https://github.com/viclafargue) +- Remove uint32 distance idx specializations ([#1235](https://github.com/rapidsai/raft/pull/1235)) [@cjnolet](https://github.com/cjnolet) +- Add innerproduct to the pairwise distance api ([#1226](https://github.com/rapidsai/raft/pull/1226)) [@benfred](https://github.com/benfred) +- Move date to build string in `conda` recipe ([#1223](https://github.com/rapidsai/raft/pull/1223)) [@ajschmidt8](https://github.com/ajschmidt8) +- Replace faiss bfKnn ([#1202](https://github.com/rapidsai/raft/pull/1202)) [@benfred](https://github.com/benfred) +- Expose KMeans `init_plus_plus` in pylibraft ([#1198](https://github.com/rapidsai/raft/pull/1198)) [@betatim](https://github.com/betatim) +- Fix `ucx-py` version ([#1184](https://github.com/rapidsai/raft/pull/1184)) [@ajschmidt8](https://github.com/ajschmidt8) +- Improve the performance of radix top-k ([#1175](https://github.com/rapidsai/raft/pull/1175)) [@yong-wang](https://github.com/yong-wang) +- Add docs build job ([#1168](https://github.com/rapidsai/raft/pull/1168)) [@AyodeAwe](https://github.com/AyodeAwe) +- Remove deprecated headers ([#1145](https://github.com/rapidsai/raft/pull/1145)) [@lowener](https://github.com/lowener) +- Simplify distance/detail to make it easier to dispatch to different kernel implementations ([#1142](https://github.com/rapidsai/raft/pull/1142)) [@ahendriksen](https://github.com/ahendriksen) +- Initial port of auto-find-k ([#1070](https://github.com/rapidsai/raft/pull/1070)) [@cjnolet](https://github.com/cjnolet) + # raft 23.02.00 (9 Feb 2023) ## 🚨 Breaking Changes diff --git a/README.md b/README.md index b77e906262..10cd7b16fc 100755 --- a/README.md +++ b/README.md @@ -146,7 +146,7 @@ in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32) output = pairwise_distance(in1, in2, metric="euclidean") ``` -The `output` array in the above example is of type `raft.common.device_ndarray`, which supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) making it interoperable with other libraries like CuPy, Numba, and PyTorch that also support it. CuPy supports DLPack, which also enables zero-copy conversion from `raft.common.device_ndarray` to JAX and Tensorflow. +The `output` array in the above example is of type `raft.common.device_ndarray`, which supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2), making it interoperable with other libraries like CuPy, Numba, PyTorch, and RAPIDS cuDF that also support it. CuPy supports DLPack, which also enables zero-copy conversion from `raft.common.device_ndarray` to JAX and Tensorflow.
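For example, the DLPack route mentioned above might look like the following sketch (illustrative only, not part of this change; it assumes CuPy and JAX are installed and uses their public `toDlpack` and `jax.dlpack.from_dlpack` APIs):

```python
import cupy as cp
import jax.dlpack

# Zero-copy view of the RAFT output through __cuda_array_interface__
cupy_array = cp.asarray(output)

# Zero-copy hand-off to JAX through the DLPack protocol
jax_array = jax.dlpack.from_dlpack(cupy_array.toDlpack())
```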
Below is an example of converting the output `pylibraft.device_ndarray` to a CuPy array: ```python @@ -160,6 +160,12 @@ import torch torch_tensor = torch.as_tensor(output, device='cuda') ``` +Or converting to a RAPIDS cuDF dataframe: +```python +import cudf +cudf_dataframe = cudf.DataFrame(output) +``` + When the corresponding library has been installed and available in your environment, this conversion can also be done automatically by all RAFT compute APIs by setting a global configuration option: ```python import pylibraft.config @@ -198,7 +203,7 @@ RAFT itself can be installed through conda, [CMake Package Manager (CPM)](https: The easiest way to install RAFT is through conda and several packages are provided. - `libraft-headers` RAFT headers -- `libraft` (optional) shared library of pre-compiled template specializations and runtime APIs. +- `libraft` (optional) shared library of pre-compiled template instantiations and runtime APIs. - `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives. - `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters. @@ -231,11 +236,11 @@ You can find an [example RAFT](cpp/template/README.md) project template in the ` Additional CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. -| Component | Target | Description | Base Dependencies | -|-------------|---------------------|-----------------------------------------------------------|---------------------------------------| -| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | -| compiled | `raft::compiled` | Pre-compiled template specializations and runtime library | raft::raft | -| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | +| Component | Target | Description | Base Dependencies | +|-------------|---------------------|----------------------------------------------------------|----------------------------------------| +| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | +| compiled | `raft::compiled` | Pre-compiled template instantiations and runtime library | raft::raft | +| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | ### Source @@ -282,7 +287,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `util`: Various reusable tools and utilities for accelerated algorithm development - `internal`: A private header-only component that hosts the code shared between benchmarks and tests. - `scripts`: Helpful scripts for development - - `src`: Compiled APIs and template specializations for the shared libraries + - `src`: Compiled APIs and template instantiations for the shared libraries - `template`: A skeleton template containing the bare-bones file structure and cmake configuration for writing applications with RAFT.
- `test`: Googletests source code - `docs`: Source code and scripts for building library documentation (Uses breath, doxygen, & pydocs) diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 5db6fa11be..e52beb22ea 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -19,7 +19,7 @@ rapids-print-env rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python) -VERSION_NUMBER=$(rapids-get-rapids-version-from-git) +VERSION_NUMBER="23.06" rapids-mamba-retry install \ --channel "${CPP_CHANNEL}" \ diff --git a/ci/release/apply_wheel_modifications.sh b/ci/release/apply_wheel_modifications.sh index efc8f0c77c..fd6c2f929e 100755 --- a/ci/release/apply_wheel_modifications.sh +++ b/ci/release/apply_wheel_modifications.sh @@ -18,3 +18,8 @@ sed -i "s/rmm/rmm${CUDA_SUFFIX}/g" python/pylibraft/pyproject.toml sed -i "s/^name = \"raft-dask\"/name = \"raft-dask${CUDA_SUFFIX}\"/g" python/raft-dask/pyproject.toml sed -i "s/pylibraft/pylibraft${CUDA_SUFFIX}/g" python/raft-dask/pyproject.toml sed -i "s/ucx-py/ucx-py${CUDA_SUFFIX}/g" python/raft-dask/pyproject.toml + +if [[ $CUDA_SUFFIX == "-cu12" ]]; then + sed -i "s/cuda-python[<=>\.,0-9]*/cuda-python>=12.0,<13.0/g" python/pylibraft/pyproject.toml + sed -i "s/cupy-cuda11x/cupy-cuda12x/g" python/pylibraft/pyproject.toml +fi diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index d8c22b4931..f6c6b08644 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -80,6 +80,7 @@ sed_runner "s/ucx-py.*\",/ucx-py==${NEXT_UCX_PY_SHORT_TAG_PEP440}.*\",/g" python for FILE in .github/workflows/*.yaml; do sed_runner "/shared-action-workflows/ s/@.*/@branch-${NEXT_SHORT_TAG}/g" "${FILE}" done +sed_runner "s/VERSION_NUMBER=\".*/VERSION_NUMBER=\"${NEXT_SHORT_TAG}\"/g" ci/build_docs.sh sed_runner "/^PROJECT_NUMBER/ s|\".*\"|\"${NEXT_SHORT_TAG}\"|g" cpp/doxygen/Doxyfile diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 0e06076f1a..aae2aa3d15 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -9,13 +9,13 @@ channels: dependencies: - breathe - c-compiler -- clang-tools=11.1.0 -- clang=11.1.0 +- clang-tools=16.0.1 +- clang=16.0.1 - cmake>=3.23.1,!=3.25.0 - cuda-profiler-api=11.8.86 -- cuda-python >=11.7.1,<12.0 +- cuda-python>=11.7.1,<12.0 - cudatoolkit=11.8 -- cupy +- cupy>=12.0.0 - cxx-compiler - cython>=0.29,<0.30 - dask-core==2023.3.2 @@ -24,7 +24,9 @@ dependencies: - distributed==2023.3.2.1 - doxygen>=1.8.20 - gcc_linux-64=11.* +- gmock>=1.13.0 - graphviz +- gtest>=1.13.0 - ipython - joblib>=0.11 - libcublas-dev=11.11.3.6 @@ -45,7 +47,7 @@ dependencies: - pytest-cov - recommonmark - rmm==23.6.* -- scikit-build>=0.13.1 +- scikit-build>=0.13.1,<0.17.2 - scikit-learn - scipy - sphinx-copybutton diff --git a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml index 5965aaef8f..3ea560025e 100644 --- a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml @@ -8,8 +8,8 @@ channels: - nvidia dependencies: - c-compiler -- clang-tools=11.1.0 -- clang=11.1.0 +- clang-tools=16.0.1 +- clang=16.0.1 - cmake>=3.23.1,!=3.25.0 - cuda-profiler-api=11.8.86 - cudatoolkit=11.8 @@ -32,6 +32,6 @@ dependencies: - nccl>=2.9.9 - ninja - nlohmann_json>=3.11.2 -- scikit-build>=0.13.1 +- 
scikit-build>=0.13.1,<0.17.2 - sysroot_linux-64==2.17 name: bench_ann_cuda-118_arch-x86_64 diff --git a/conda/recipes/libraft/conda_build_config.yaml b/conda/recipes/libraft/conda_build_config.yaml index 2a66f213a7..bec773d26d 100644 --- a/conda/recipes/libraft/conda_build_config.yaml +++ b/conda/recipes/libraft/conda_build_config.yaml @@ -17,7 +17,7 @@ nccl_version: - ">=2.9.9" gtest_version: - - "=1.10.0" + - ">=1.13.0" glog_version: - ">=0.6.0" diff --git a/conda/recipes/libraft/meta.yaml b/conda/recipes/libraft/meta.yaml index 8ec9cc10c6..b89fcfb788 100644 --- a/conda/recipes/libraft/meta.yaml +++ b/conda/recipes/libraft/meta.yaml @@ -36,6 +36,7 @@ outputs: - SCCACHE_S3_KEY_PREFIX=libraft-aarch64 # [aarch64] - SCCACHE_S3_KEY_PREFIX=libraft-linux64 # [linux64] - SCCACHE_S3_USE_SSL + - SCCACHE_S3_NO_CREDENTIALS number: {{ GIT_DESCRIBE_NUMBER }} string: cuda{{ cuda_major }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} ignore_run_exports_from: @@ -52,6 +53,9 @@ outputs: host: - librmm ={{ minor_version }} - cudatoolkit {{ cuda_version }} + run: + - {{ pin_compatible('cudatoolkit', max_pin='x', min_pin='x') }} + - librmm ={{ minor_version }} about: home: https://rapids.ai/ license: Apache-2.0 @@ -68,7 +72,6 @@ outputs: run: - {{ pin_subpackage('libraft-headers-only', exact=True) }} - cuda-profiler-api {{ cuda_profiler_api_run_version }} - - cudatoolkit {{ cuda_version }} - librmm ={{ minor_version }} - libcublas {{ libcublas_run_version }} - libcublas-dev {{ libcublas_run_version }} @@ -101,6 +104,7 @@ outputs: - sysroot_{{ target_platform }} {{ sysroot_version }} host: - {{ pin_subpackage('libraft-headers', exact=True) }} + - cudatoolkit {{ cuda_version }} - cuda-profiler-api {{ cuda_profiler_api_host_version }} - libcublas {{ libcublas_host_version }} - libcublas-dev {{ libcublas_host_version }} @@ -135,6 +139,7 @@ outputs: - sysroot_{{ target_platform }} {{ sysroot_version }} host: - {{ pin_subpackage('libraft', exact=True) }} + - cudatoolkit {{ cuda_version }} - cuda-profiler-api {{ cuda_profiler_api_host_version }} - gmock {{ gtest_version }} - gtest {{ gtest_version }} @@ -200,6 +205,7 @@ outputs: - sysroot_{{ target_platform }} {{ sysroot_version }} host: - {{ pin_subpackage('libraft', exact=True) }} + - cudatoolkit {{ cuda_version }} - libcublas {{ libcublas_host_version }} - libcublas-dev {{ libcublas_host_version }} - glog {{ glog_version }} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6461492169..cddfa4b38d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -70,13 +70,11 @@ option(RAFT_COMPILE_LIBRARY "Enable building raft shared library instantiations" ${RAFT_COMPILE_LIBRARY_DEFAULT} ) - -# Needed because GoogleBenchmark changes the state of FindThreads.cmake, causing subsequent runs -# to have different values for the `Threads::Threads` target. Setting this flag ensures +# Needed because GoogleBenchmark changes the state of FindThreads.cmake, causing subsequent runs to +# have different values for the `Threads::Threads` target. 
Setting this flag ensures # `Threads::Threads` is the same value across all builds so that cache hits occur set(THREADS_PREFER_PTHREAD_FLAG ON) - include(CMakeDependentOption) # cmake_dependent_option( RAFT_USE_FAISS_STATIC "Build and statically link the FAISS library for # nearest neighbors search on GPU" ON RAFT_COMPILE_LIBRARY OFF ) @@ -265,180 +263,135 @@ set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) add_library( raft_lib - src/distance/pairwise_distance.cu - src/distance/fused_l2_min_arg.cu - src/cluster/update_centroids_float.cu - src/cluster/update_centroids_double.cu - src/cluster/cluster_cost_float.cu - src/cluster/cluster_cost_double.cu - src/neighbors/refine_d_int64_t_float.cu - src/neighbors/refine_d_int64_t_int8_t.cu - src/neighbors/refine_d_int64_t_uint8_t.cu - src/neighbors/refine_h_int64_t_float.cu - src/neighbors/refine_h_int64_t_int8_t.cu - src/neighbors/refine_h_int64_t_uint8_t.cu - src/neighbors/specializations/refine_d_int64_t_float.cu - src/neighbors/specializations/refine_d_int64_t_int8_t.cu - src/neighbors/specializations/refine_d_int64_t_uint8_t.cu - src/neighbors/specializations/refine_h_int64_t_float.cu - src/neighbors/specializations/refine_h_int64_t_int8_t.cu - src/neighbors/specializations/refine_h_int64_t_uint8_t.cu - src/cluster/kmeans_fit_float.cu - src/cluster/kmeans_fit_double.cu - src/cluster/kmeans_init_plus_plus_double.cu - src/cluster/kmeans_init_plus_plus_float.cu - src/distance/specializations/detail/canberra_double_double_double_int.cu - src/distance/specializations/detail/canberra_float_float_float_int.cu - src/distance/specializations/detail/correlation_double_double_double_int.cu - src/distance/specializations/detail/correlation_float_float_float_int.cu - src/distance/specializations/detail/cosine_double_double_double_int.cu - src/distance/specializations/detail/cosine_float_float_float_int.cu - src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu - src/distance/specializations/detail/inner_product_float_float_float_int.cu - src/distance/specializations/detail/inner_product_double_double_double_int.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu - src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu - src/distance/specializations/detail/kernels/gram_matrix_base_double.cu - src/distance/specializations/detail/kernels/gram_matrix_base_float.cu - src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu - src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu - # These are somehow missing a kernel definition which is causing a compile error. 
- # src/distance/specializations/detail/kernels/rbf_kernel_double.cu - # src/distance/specializations/detail/kernels/rbf_kernel_float.cu - src/neighbors/brute_force_knn_int64_t_float.cu - src/distance/specializations/detail/kernels/tanh_kernel_double.cu - src/distance/specializations/detail/kernels/tanh_kernel_float.cu - src/distance/specializations/detail/kl_divergence_float_float_float_int.cu - src/distance/specializations/detail/kl_divergence_double_double_double_int.cu - src/distance/specializations/detail/l1_float_float_float_int.cu - src/distance/specializations/detail/l1_double_double_double_int.cu - src/distance/specializations/detail/l2_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l_inf_double_double_double_int.cu - src/distance/specializations/detail/l_inf_float_float_float_int.cu - src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/russel_rao_double_double_double_int.cu - src/distance/specializations/detail/russel_rao_float_float_float_int.cu - src/distance/specializations/fused_l2_nn_double_int.cu - src/distance/specializations/fused_l2_nn_double_int64.cu - src/distance/specializations/fused_l2_nn_float_int.cu - src/distance/specializations/fused_l2_nn_float_int64.cu - src/matrix/specializations/detail/select_k_float_uint32_t.cu - src/matrix/specializations/detail/select_k_float_int64_t.cu - src/matrix/specializations/detail/select_k_half_uint32_t.cu - src/matrix/specializations/detail/select_k_half_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu + src/core/logger.cpp + src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu + 
src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_rbf.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu + src/distance/distance.cu + src/distance/fused_l2_nn.cu + src/linalg/detail/coalesced_reduction.cu + src/matrix/detail/select_k_double_int64_t.cu + src/matrix/detail/select_k_double_uint32_t.cu + src/matrix/detail/select_k_float_int64_t.cu + src/matrix/detail/select_k_float_uint32_t.cu + src/matrix/detail/select_k_half_int64_t.cu + src/matrix/detail/select_k_half_uint32_t.cu + src/neighbors/ball_cover.cu + src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu + src/neighbors/brute_force_knn_int_float_int.cu + src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu + src/neighbors/detail/ivf_flat_search.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu + src/neighbors/detail/selection_faiss_int32_t_float.cu + src/neighbors/detail/selection_faiss_int_double.cu + src/neighbors/detail/selection_faiss_long_float.cu + src/neighbors/detail/selection_faiss_size_t_double.cu + src/neighbors/detail/selection_faiss_size_t_float.cu + src/neighbors/detail/selection_faiss_uint32_t_float.cu + src/neighbors/ivf_flat_build_float_int64_t.cu + src/neighbors/ivf_flat_build_int8_t_int64_t.cu + src/neighbors/ivf_flat_build_uint8_t_int64_t.cu + src/neighbors/ivf_flat_extend_float_int64_t.cu + src/neighbors/ivf_flat_extend_int8_t_int64_t.cu + src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu + src/neighbors/ivf_flat_search_float_int64_t.cu + src/neighbors/ivf_flat_search_int8_t_int64_t.cu + src/neighbors/ivf_flat_search_uint8_t_int64_t.cu + src/neighbors/ivfpq_build_float_int64_t.cu + src/neighbors/ivfpq_build_int8_t_int64_t.cu + src/neighbors/ivfpq_build_uint8_t_int64_t.cu + src/neighbors/ivfpq_extend_float_int64_t.cu + src/neighbors/ivfpq_extend_int8_t_int64_t.cu + src/neighbors/ivfpq_extend_uint8_t_int64_t.cu src/neighbors/ivfpq_search_float_int64_t.cu src/neighbors/ivfpq_search_int8_t_int64_t.cu src/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - 
src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu - src/neighbors/specializations/ball_cover_all_knn_query.cu - src/neighbors/specializations/ball_cover_build_index.cu - src/neighbors/specializations/ball_cover_knn_query.cu - src/neighbors/specializations/fused_l2_knn_long_float_true.cu - src/neighbors/specializations/fused_l2_knn_long_float_false.cu - src/neighbors/specializations/fused_l2_knn_int_float_true.cu - src/neighbors/specializations/fused_l2_knn_int_float_false.cu - src/neighbors/ivf_flat_search.cu - src/neighbors/ivf_flat_build.cu - src/neighbors/specializations/ivfflat_build_float_int64_t.cu - src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu - 
src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_float_int64_t.cu - src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_float_int64_t.cu - src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu - src/neighbors/ivfpq_search_float_int64_t.cu - src/neighbors/ivfpq_search_int8_t_int64_t.cu - src/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu + src/neighbors/refine_float_float.cu + src/neighbors/refine_int8_t_float.cu + src/neighbors/refine_uint8_t_float.cu + src/raft_runtime/cluster/cluster_cost.cuh + src/raft_runtime/cluster/cluster_cost_double.cu + src/raft_runtime/cluster/cluster_cost_float.cu + src/raft_runtime/cluster/kmeans_fit_double.cu + src/raft_runtime/cluster/kmeans_fit_float.cu + 
src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu + src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu + src/raft_runtime/cluster/update_centroids.cuh + src/raft_runtime/cluster/update_centroids_double.cu + src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_l2_min_arg.cu + src/raft_runtime/distance/pairwise_distance.cu + src/raft_runtime/matrix/select_k_float_int64_t.cu + src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu + src/raft_runtime/neighbors/ivf_flat_build.cu + src/raft_runtime/neighbors/ivf_flat_search.cu + src/raft_runtime/neighbors/ivfpq_build.cu + src/raft_runtime/neighbors/ivfpq_deserialize.cu + src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_serialize.cu + src/raft_runtime/neighbors/refine_d_int64_t_float.cu + src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_float.cu + src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu + src/raft_runtime/random/rmat_rectangular_generator_int_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int_float.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu + src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu + src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu + src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu + src/util/memory_pool.cpp ) set_target_properties( raft_lib @@ -464,7 +417,13 @@ if(RAFT_COMPILE_LIBRARY) raft_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" ) - target_compile_definitions(raft_lib INTERFACE "RAFT_COMPILED") + + # RAFT_COMPILED is set during compilation of libraft.so as well as downstream libraries (due to + # "PUBLIC") + target_compile_definitions(raft_lib PUBLIC "RAFT_COMPILED") + + # RAFT_EXPLICIT_INSTANTIATE_ONLY is set during compilation of libraft.so (due to "PRIVATE") + target_compile_definitions(raft_lib PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp index b4d8fbeee3..c34b95010f 100644 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ b/cpp/bench/ann/src/common/benchmark.hpp @@ -14,7 +14,7 @@ * limitations under the License. 
*/ #ifdef NVTX -#include +#include #endif #include diff --git a/cpp/bench/ann/src/common/dataset.h b/cpp/bench/ann/src/common/dataset.h index 1244935c99..46dd66d649 100644 --- a/cpp/bench/ann/src/common/dataset.h +++ b/cpp/bench/ann/src/common/dataset.h @@ -47,7 +47,7 @@ class BinFile { uint32_t subset_first_row = 0, uint32_t subset_size = 0); ~BinFile() { fclose(fp_); } - BinFile(const BinFile&) = delete; + BinFile(const BinFile&) = delete; BinFile& operator=(const BinFile&) = delete; void get_shape(size_t* nrows, int* ndims) @@ -219,7 +219,7 @@ class Dataset { Dataset(const std::string& name, const std::string& distance) : name_(name), distance_(distance) { } - Dataset(const Dataset&) = delete; + Dataset(const Dataset&) = delete; Dataset& operator=(const Dataset&) = delete; virtual ~Dataset(); diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index d8e98ce2a9..baff1b1c45 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -22,10 +22,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include "../common/ann_types.hpp" #include "../common/benchmark_util.hpp" #undef WARP_SIZE diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ivf_flat.cu index ff108080b5..bcd23723a4 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_flat.cu @@ -15,12 +15,8 @@ */ #include "raft_ivf_flat_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; -} // namespace raft::bench::ann \ No newline at end of file +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 8b2a7d329b..0a80eef1b5 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ivf_pq.cu index 338bc9a32f..2efe14631b 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_pq.cu @@ -15,10 +15,6 @@ */ #include "raft_ivf_pq_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfPQ; template class RaftIvfPQ; diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 70dff81847..517272e6cf 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -42,6 +42,7 @@ template class RaftIvfPQ : public ANN { public: using typename ANN::AnnSearchParam; + using ANN::dim_; struct SearchParam : public AnnSearchParam { raft::neighbors::ivf_pq::search_params pq_param; @@ -118,7 +119,7 @@ void RaftIvfPQ::load(const std::string& file) template void RaftIvfPQ::build(const T* dataset, size_t nrow, cudaStream_t) { - auto dataset_v = raft::make_device_matrix_view(dataset, IdxT(nrow), index_->dim()); + auto dataset_v = raft::make_device_matrix_view(dataset, IdxT(nrow), dim_); index_.emplace(raft::runtime::neighbors::ivf_pq::build(handle_, index_params_, dataset_v)); return; diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index f6499623dd..505ca32886 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -17,7 +17,7 @@ 
function(ConfigureBench) - set(options OPTIONAL LIB) + set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) set(oneValueArgs NAME) set(multiValueArgs PATH TARGETS CONFIGURATIONS) @@ -55,6 +55,10 @@ function(ConfigureBench) "$<$:${RAFT_CUDA_FLAGS}>" ) + if(ConfigureBench_EXPLICIT_INSTANTIATE_ONLY) + target_compile_definitions(${BENCH_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + endif() + target_include_directories( ${BENCH_NAME} PUBLIC "$" ) @@ -71,7 +75,7 @@ endfunction() if(BUILD_PRIMS_BENCH) ConfigureBench( NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu - bench/prims/main.cpp OPTIONAL LIB + bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -93,6 +97,7 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -112,7 +117,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -139,5 +144,6 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) endif() diff --git a/cpp/bench/prims/cluster/kmeans.cu b/cpp/bench/prims/cluster/kmeans.cu index af7afb8037..3147960f72 100644 --- a/cpp/bench/prims/cluster/kmeans.cu +++ b/cpp/bench/prims/cluster/kmeans.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBenchParams { diff --git a/cpp/bench/prims/cluster/kmeans_balanced.cu b/cpp/bench/prims/cluster/kmeans_balanced.cu index 6bda43bdb2..42a8f7967c 100644 --- a/cpp/bench/prims/cluster/kmeans_balanced.cu +++ b/cpp/bench/prims/cluster/kmeans_balanced.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBalancedBenchParams { diff --git a/cpp/bench/prims/common/benchmark.hpp b/cpp/bench/prims/common/benchmark.hpp index 4b6e1ba286..1e783eb338 100644 --- a/cpp/bench/prims/common/benchmark.hpp +++ b/cpp/bench/prims/common/benchmark.hpp @@ -113,8 +113,19 @@ class fixture { raft::device_resources handle; rmm::cuda_stream_view stream; - fixture() : stream{handle.get_stream()} + fixture(bool use_pool_memory_resource = false) : stream{handle.get_stream()} { + // Cache the memory pool between benchmark runs, since it is expensive to create. + // This speeds up the time required to run the select_k bench by over 3x. + // This is part of the fixture class here so that the pool will get cleaned + // up, rather than outliving the benchmarks that require it.
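+ // A benchmark opts in by passing `true` from its own constructor, as
+ // select_k.cu does later in this diff, e.g. (sketch only; the real
+ // constructor also initializes its device buffers):
+ //   explicit selection(const select::params& p) : fixture(true), params_(p) {...}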
+ static std::unique_ptr<using_pool_memory_res> memory_pool; + if (use_pool_memory_resource) { + if (!memory_pool) { memory_pool.reset(new using_pool_memory_res()); } + } else if (memory_pool) { + memory_pool.reset(); + } + int l2_cache_size = 0; int device_id = 0; RAFT_CUDA_TRY(cudaGetDevice(&device_id)); diff --git a/cpp/bench/prims/distance/distance_common.cuh b/cpp/bench/prims/distance/distance_common.cuh index 9b5d67a46f..dff3401b62 100644 --- a/cpp/bench/prims/distance/distance_common.cuh +++ b/cpp/bench/prims/distance/distance_common.cuh @@ -17,9 +17,6 @@ #include #include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu index 1c45572782..24c0cbf8f9 100644 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ b/cpp/bench/prims/distance/fused_l2_nn.cu @@ -16,10 +16,8 @@ #include #include +#include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/kernels.cu b/cpp/bench/prims/distance/kernels.cu index 4407bdcf83..53d97c1fc7 100644 --- a/cpp/bench/prims/distance/kernels.cu +++ b/cpp/bench/prims/distance/kernels.cu @@ -13,10 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/bench/prims/distance/masked_nn.cu b/cpp/bench/prims/distance/masked_nn.cu index f9f234187d..c804ecb3a1 100644 --- a/cpp/bench/prims/distance/masked_nn.cu +++ b/cpp/bench/prims/distance/masked_nn.cu @@ -30,10 +30,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::distance::masked_nn { // Introduce various sparsity patterns diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 870119db52..d0bc993cc1 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -23,10 +23,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include @@ -46,7 +42,8 @@ using namespace raft::bench; // NOLINT template struct selection : public fixture { explicit selection(const select::params& p) - : params_(p), + : fixture(true), + params_(p), in_dists_(p.batch_size * p.len, stream), in_ids_(p.batch_size * p.len, stream), out_dists_(p.batch_size * p.k, stream), @@ -76,7 +73,6 @@ struct selection : public fixture { void run_benchmark(::benchmark::State& state) override // NOLINT { device_resources handle{stream}; - using_pool_memory_res res; try { std::ostringstream label_stream; label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; @@ -157,34 +153,33 @@ const std::vector kInputs{ {10, 1000000, 256, true, false, true}, }; -#define SELECTION_REGISTER(KeyT, IdxT, A) \ - namespace BENCHMARK_PRIVATE_NAME(selection) \ - { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ } -SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT -SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT
-SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT +SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT SELECTION_REGISTER(double, uint32_t, kRadix11bitsExtraPass); // NOLINT SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT -SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT } // namespace raft::matrix diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 8f0b1cb5d9..8239fa4f89 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -24,10 +24,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include @@ -226,7 +222,8 @@ struct brute_force_knn { template struct knn : public fixture { explicit knn(const params& p, const TransferStrategy& strategy, const Scope& scope) - : params_(p), + : fixture(true), + params_(p), strategy_(strategy), scope_(scope), dev_mem_res_(strategy == TransferStrategy::MANAGED), @@ -278,8 +275,6 @@ struct knn : public fixture { "device (TransferStrategy::NO_COPY)"); } - using_pool_memory_res default_resource; - try { std::ostringstream label_stream; label_stream << params_ << "#" << strategy_ << "#" << scope_; @@ -384,11 +379,10 @@ inline const std::vector kNoCopyOnly{TransferStrategy::NO_COPY inline const std::vector kScopeFull{Scope::BUILD_SEARCH}; inline const std::vector kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD}; -#define KNN_REGISTER(ValT, IdxT, ImplT, inputs, strats, scope) \ - namespace BENCHMARK_PRIVATE_NAME(knn) \ - { \ - using KNN = knn>; \ - RAFT_BENCH_REGISTER(KNN, #ValT "/" #IdxT "/" #ImplT, inputs, strats, scope); \ +#define KNN_REGISTER(ValT, IdxT, ImplT, inputs, strats, scope) \ + namespace BENCHMARK_PRIVATE_NAME(knn) { \ + using KNN = knn>; \ + RAFT_BENCH_REGISTER(KNN, #ValT "/" #IdxT "/" #ImplT, inputs, strats, scope); \ } } // namespace 
raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/refine_float_int64_t.cu b/cpp/bench/prims/neighbors/refine_float_int64_t.cu index 43be330e9b..bbedc1ae64 100644 --- a/cpp/bench/prims/neighbors/refine_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_float_int64_t.cu @@ -17,11 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu index 1d7cb8c8aa..4952361f03 100644 --- a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu @@ -17,10 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/cmake/modules/ConfigureCUDA.cmake b/cpp/cmake/modules/ConfigureCUDA.cmake index c733d46985..ea8a077b0c 100644 --- a/cpp/cmake/modules/ConfigureCUDA.cmake +++ b/cpp/cmake/modules/ConfigureCUDA.cmake @@ -17,8 +17,16 @@ if(DISABLE_DEPRECATION_WARNINGS) list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wno-deprecated-declarations) endif() +# Be very strict when compiling with GCC as host compiler (and thus more lenient when compiling with +# clang) if(CMAKE_COMPILER_IS_GNUCXX) list(APPEND RAFT_CXX_FLAGS -Wall -Werror -Wno-unknown-pragmas -Wno-error=deprecated-declarations) + list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations) + + # set warnings as errors + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2.0) + list(APPEND RAFT_CUDA_FLAGS -Werror=all-warnings) + endif() endif() if(CUDA_LOG_COMPILE_TIME) @@ -31,12 +39,6 @@ list(APPEND RAFT_CUDA_FLAGS "-DCUDA_API_PER_THREAD_DEFAULT_STREAM") # make sure we produce smallest binary size list(APPEND RAFT_CUDA_FLAGS -Xfatbin=-compress-all) -# set warnings as errors -if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.2.0) - list(APPEND RAFT_CUDA_FLAGS -Werror=all-warnings) -endif() -list(APPEND RAFT_CUDA_FLAGS -Xcompiler=-Wall,-Werror,-Wno-error=deprecated-declarations) - # Option to enable line info in CUDA device compilation to allow introspection when profiling / # memchecking if(CUDA_ENABLE_LINEINFO) diff --git a/cpp/cmake/thirdparty/get_glog.cmake b/cpp/cmake/thirdparty/get_glog.cmake index 9334224de5..35a9170f99 100644 --- a/cpp/cmake/thirdparty/get_glog.cmake +++ b/cpp/cmake/thirdparty/get_glog.cmake @@ -26,7 +26,6 @@ function(find_and_configure_glog) CPM_ARGS GIT_REPOSITORY https://github.com/${PKG_FORK}/glog.git GIT_TAG ${PKG_PINNED_TAG} - SOURCE_SUBDIR cpp EXCLUDE_FROM_ALL ${PKG_EXCLUDE_FROM_ALL} ) @@ -46,4 +45,4 @@ find_and_configure_glog(VERSION 0.6.0 FORK google PINNED_TAG v0.6.0 EXCLUDE_FROM_ALL ON - ) \ No newline at end of file + ) diff --git a/cpp/doxygen/Doxyfile b/cpp/doxygen/Doxyfile index 17a1e0caca..1948169c91 100644 --- a/cpp/doxygen/Doxyfile +++ b/cpp/doxygen/Doxyfile @@ -918,6 +918,7 @@ EXCLUDE_SYMLINKS = NO # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories for example use the pattern */test/* +# TODO: remove specializations from exclude patterns when headers have been removed. 
EXCLUDE_PATTERNS = */detail/* \ */specializations/* \ */thirdparty/* diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh index 3d23c809c3..eb89ebe402 100644 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -436,7 +436,7 @@ __global__ void __launch_bounds__((WarpSize * BlockDimY)) adjust_centers_kernel(MathT* centers, // [n_clusters, dim] IdxT n_clusters, IdxT dim, - const T* dataset, // [n_rows, dim] + const T* dataset, // [n_rows, dim] IdxT n_rows, const LabelT* labels, // [n_rows] const CounterT* cluster_sizes, // [n_clusters] @@ -976,7 +976,7 @@ void build_hierarchical(const raft::device_resources& handle, raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size)); if (pool_guard) { RAFT_LOG_DEBUG("build_hierarchical: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); + mem_per_row * size_t(max_minibatch_size)); } // Precompute the L2 norm of the dataset if relevant. diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 76fc22e99e..cca1cbb6e9 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp index 9a4fcfef60..cd815622bf 100644 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ b/cpp/include/raft/cluster/single_linkage_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -77,9 +77,7 @@ class linkage_output { } }; -class linkage_output_int : public linkage_output { -}; -class linkage_output_int64 : public linkage_output { -}; +class linkage_output_int : public linkage_output {}; +class linkage_output_int64 : public linkage_output {}; }; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/specializations.cuh b/cpp/include/raft/cluster/specializations.cuh index 9b68d7adc9..9588a7f329 100644 --- a/cpp/include/raft/cluster/specializations.cuh +++ b/cpp/include/raft/cluster/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __CLUSTER_SPECIALIZATIONS_H -#define __CLUSTER_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/common/cub_wrappers.cuh b/cpp/include/raft/common/cub_wrappers.cuh index e80d7cccd9..dd8fc2d103 100644 --- a/cpp/include/raft/common/cub_wrappers.cuh +++ b/cpp/include/raft/common/cub_wrappers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -24,9 +24,9 @@ #pragma once -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please note that there is no equivalent in RAFT's public API" +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please note that there is no equivalent in RAFT's public API" " so this file will eventually be removed altogether.") #include diff --git a/cpp/include/raft/common/device_loads_stores.cuh b/cpp/include/raft/common/device_loads_stores.cuh index f3cfbd81cc..6c62cd70cc 100644 --- a/cpp/include/raft/common/device_loads_stores.cuh +++ b/cpp/include/raft/common/device_loads_stores.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,8 @@ #pragma once -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use the raft/util version instead.") +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the raft/util version instead.") #include diff --git a/cpp/include/raft/common/scatter.cuh b/cpp/include/raft/common/scatter.cuh index 0e83f9a5cd..72de79a596 100644 --- a/cpp/include/raft/common/scatter.cuh +++ b/cpp/include/raft/common/scatter.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,8 @@ #pragma once -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use the raft/matrix version instead.") +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the raft/matrix version instead.") #include diff --git a/cpp/include/raft/common/seive.hpp b/cpp/include/raft/common/seive.hpp index 633c8dd3e1..433b032b0f 100644 --- a/cpp/include/raft/common/seive.hpp +++ b/cpp/include/raft/common/seive.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,8 @@ #pragma once -#pragma message(__FILE__ \ - " is deprecated and will be removed in a future release." \ - " Please use the raft/util version instead.") +#pragma message(__FILE__ \ + " is deprecated and will be removed in a future release." \ + " Please use the raft/util version instead.") #include diff --git a/cpp/include/raft/core/coo_matrix.hpp b/cpp/include/raft/core/coo_matrix.hpp index efab8a1601..a5f7c05493 100644 --- a/cpp/include/raft/core/coo_matrix.hpp +++ b/cpp/include/raft/core/coo_matrix.hpp @@ -71,12 +71,6 @@ class coordinate_structure_view { } - /** - * Create a view from this view. 
Note that this is for interface compatibility - * @return - */ - view_type view() { return view_type(rows_, cols_, this->get_n_rows(), this->get_n_cols()); } - /** * Return span containing underlying rows array * @return span containing underlying rows array @@ -209,6 +203,10 @@ class coo_matrix_view coordinate_structure_view, is_device> { public: + using element_type = ElementType; + using row_type = RowType; + using col_type = ColType; + using nnz_type = NZType; coo_matrix_view(raft::span element_span, coordinate_structure_view structure_view) : sparse_matrix_view { public: using element_type = ElementType; + using row_type = RowType; + using col_type = ColType; + using nnz_type = NZType; using structure_view_type = typename structure_type::view_type; using container_type = typename ContainerPolicy::container_type; using sparse_matrix_type = @@ -258,14 +259,9 @@ class coo_matrix // Constructor that owns the data but not the structure template > - coo_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + coo_matrix(raft::resources const& handle, structure_type structure) noexcept( std::is_nothrow_default_constructible_v) : sparse_matrix_type(handle, structure){}; - /** - * Return a view of the structure underlying this matrix - * @return - */ - structure_view_type structure_view() { return this->structure_.get()->view(); } /** * Initialize the sparsity on this instance if it was not known upon construction @@ -277,7 +273,20 @@ class coo_matrix void initialize_sparsity(NZType nnz) { sparse_matrix_type::initialize_sparsity(nnz); - this->structure_.get()->initialize_sparsity(nnz); + this->structure_.initialize_sparsity(nnz); + } + + /** + * Return a view of the structure underlying this matrix + * @return + */ + structure_view_type structure_view() + { + if constexpr (get_sparsity_type() == SparsityType::OWNING) { + return this->structure_.view(); + } else { + return this->structure_; + } } }; } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/csr_matrix.hpp b/cpp/include/raft/core/csr_matrix.hpp index fac656b3f9..95d09d3eea 100644 --- a/cpp/include/raft/core/csr_matrix.hpp +++ b/cpp/include/raft/core/csr_matrix.hpp @@ -87,12 +87,6 @@ class compressed_structure_view */ span get_indices() override { return indices_; } - /** - * Create a view from this view. 
Note that this is for interface compatibility - * @return - */ - view_type view() { return view_type(indptr_, indices_, this->get_n_cols()); } - protected: raft::span indptr_; raft::span indices_; @@ -147,7 +141,7 @@ class compressed_structure constexpr auto operator=(compressed_structure const&) noexcept( std::is_nothrow_copy_assignable::value) -> compressed_structure& = default; - constexpr auto operator =(compressed_structure&&) noexcept( + constexpr auto operator=(compressed_structure&&) noexcept( std::is_nothrow_move_assignable::value) -> compressed_structure& = default; @@ -221,6 +215,10 @@ class csr_matrix_view compressed_structure_view, is_device> { public: + using element_type = ElementType; + using indptr_type = IndptrType; + using indices_type = IndicesType; + using nnz_type = NZType; csr_matrix_view( raft::span element_span, compressed_structure_view structure_view) @@ -249,6 +247,9 @@ class csr_matrix ContainerPolicy> { public: using element_type = ElementType; + using indptr_type = IndptrType; + using indices_type = IndicesType; + using nnz_type = NZType; using structure_view_type = typename structure_type::view_type; static constexpr auto get_sparsity_type() { return sparsity_type; } using sparse_matrix_type = @@ -271,7 +272,7 @@ class csr_matrix template > - csr_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + csr_matrix(raft::resources const& handle, structure_type structure) noexcept( std::is_nothrow_default_constructible_v) : sparse_matrix_type(handle, structure){}; @@ -284,13 +285,20 @@ class csr_matrix void initialize_sparsity(NZType nnz) { sparse_matrix_type::initialize_sparsity(nnz); - this->structure_.get()->initialize_sparsity(nnz); + this->structure_.initialize_sparsity(nnz); } /** * Return a view of the structure underlying this matrix * @return */ - structure_view_type structure_view() { return this->structure_.get()->view(); } + structure_view_type structure_view() + { + if constexpr (get_sparsity_type() == SparsityType::OWNING) { + return this->structure_.view(); + } else { + return this->structure_; + } + } }; } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/core/cublas_macros.hpp b/cpp/include/raft/core/cublas_macros.hpp index 855c1228f7..5c56240ccf 100644 --- a/cpp/include/raft/core/cublas_macros.hpp +++ b/cpp/include/raft/core/cublas_macros.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ #include ///@todo: enable this once we have logger enabled -//#include +// #include #include diff --git a/cpp/include/raft/core/cusolver_macros.hpp b/cpp/include/raft/core/cusolver_macros.hpp index 8f7caf65f3..4477d32118 100644 --- a/cpp/include/raft/core/cusolver_macros.hpp +++ b/cpp/include/raft/core/cusolver_macros.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ #include #include ///@todo: enable this once logging is enabled -//#include +// #include #include #include diff --git a/cpp/include/raft/core/cusparse_macros.hpp b/cpp/include/raft/core/cusparse_macros.hpp index 8a9aab55f7..21a25ae28c 100644 --- a/cpp/include/raft/core/cusparse_macros.hpp +++ b/cpp/include/raft/core/cusparse_macros.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include #include ///@todo: enable this once logging is enabled -//#include +// #include #define _CUSPARSE_ERR_TO_STR(err) \ case err: return #err; diff --git a/cpp/include/raft/core/detail/logger.hpp b/cpp/include/raft/core/detail/logger.hpp index 619fb89452..532aee4d90 100644 --- a/cpp/include/raft/core/detail/logger.hpp +++ b/cpp/include/raft/core/detail/logger.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,8 @@ */ #pragma once -#pragma message(__FILE__ \ - " is deprecated and will be removed in future releases." \ - " Please use the version instead.") +#pragma message(__FILE__ \ + " is deprecated and will be removed in future releases." \ + " Please use the version instead.") #include diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp index bfb47437ad..390acea697 100644 --- a/cpp/include/raft/core/detail/macros.hpp +++ b/cpp/include/raft/core/detail/macros.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,40 @@ #define RAFT_INLINE_FUNCTION _RAFT_HOST_DEVICE _RAFT_FORCEINLINE #endif +// The RAFT_INLINE_CONDITIONAL macro is a conditional inline specifier that removes +// the inline specification when RAFT_COMPILED is defined. +// +// When RAFT_COMPILED is not defined, functions may be defined in multiple +// translation units and we do not want that to lead to linker errors. +// +// When RAFT_COMPILED is defined, this serves two purposes: +// +// 1. It triggers a multiple definition error message when memory_pool-inl.hpp +// (for instance) is accidentally included in multiple translation units. +// +// 2. We want function definitions to be non-inline, because non-inline function +// symbols are always exported in the object symbol table. For inline functions, +// the compiler may elide the external symbol, which results in linker errors. +#ifdef RAFT_COMPILED +#define RAFT_INLINE_CONDITIONAL +#else +#define RAFT_INLINE_CONDITIONAL inline +#endif // RAFT_COMPILED + +// The RAFT_WEAK_FUNCTION specifies that: +// +// 1. A function may be defined in multiple translation units (like inline) +// +// 2. It must still emit an external symbol (unlike inline). This enables declaring +// a function signature in an `-ext` header and defining it in a source file. +// +// From +// https://gcc.gnu.org/onlinedocs/gcc/Common-Function-Attributes.html#Common-Function-Attributes: +// +// "The weak attribute causes a declaration of an external symbol to be emitted +// as a weak symbol rather than a global."
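+//
+// A minimal sketch of the pattern this enables, with a hypothetical function `foo`:
+//
+//   // foo-ext.hpp: declaration only
+//   RAFT_WEAK_FUNCTION bool foo();
+//
+//   // foo-inl.hpp: definition; every translation unit that includes this emits
+//   // a weak symbol, and the linker keeps exactly one without raising a
+//   // multiple-definition error
+//   RAFT_WEAK_FUNCTION bool foo() { return true; }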
+#define RAFT_WEAK_FUNCTION __attribute__((weak)) + /** * Some macro magic to remove optional parentheses of a macro argument. * See https://stackoverflow.com/a/62984543 diff --git a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp index df89811636..d0aea4168e 100644 --- a/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp +++ b/cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp @@ -74,7 +74,7 @@ namespace numpy_serializer { #if RAFT_SYSTEM_LITTLE_ENDIAN == 1 #define RAFT_NUMPY_HOST_ENDIAN_CHAR RAFT_NUMPY_LITTLE_ENDIAN_CHAR -#else // RAFT_SYSTEM_LITTLE_ENDIAN == 1 +#else // RAFT_SYSTEM_LITTLE_ENDIAN == 1 #define RAFT_NUMPY_HOST_ENDIAN_CHAR RAFT_NUMPY_BIG_ENDIAN_CHAR #endif // RAFT_SYSTEM_LITTLE_ENDIAN == 1 @@ -110,11 +110,9 @@ struct header_t { }; template -struct is_complex : std::false_type { -}; +struct is_complex : std::false_type {}; template -struct is_complex> : std::true_type { -}; +struct is_complex> : std::true_type {}; template , bool> = true> inline dtype_t get_numpy_dtype() diff --git a/cpp/include/raft/core/detail/nvtx.hpp b/cpp/include/raft/core/detail/nvtx.hpp index 4a16ec81bd..e734c99029 100644 --- a/cpp/include/raft/core/detail/nvtx.hpp +++ b/cpp/include/raft/core/detail/nvtx.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,18 +18,18 @@ #include -namespace raft::common::nvtx::detail { - #ifdef NVTX_ENABLED #include #include #include -#include +#include #include #include #include +namespace raft::common::nvtx::detail { + /** * @brief An internal struct to store associated state with the color * generator @@ -191,7 +191,11 @@ inline void pop_range() nvtxDomainRangePop(domain_store::value()); } -#else // NVTX_ENABLED +} // namespace raft::common::nvtx::detail + +#else // NVTX_ENABLED + +namespace raft::common::nvtx::detail { template inline void push_range(const char* format, Args... args) @@ -203,6 +207,6 @@ inline void pop_range() { } -#endif // NVTX_ENABLED - } // namespace raft::common::nvtx::detail + +#endif // NVTX_ENABLED diff --git a/cpp/include/raft/core/detail/span.hpp b/cpp/include/raft/core/detail/span.hpp index 20500d618b..e6ccb8535c 100644 --- a/cpp/include/raft/core/detail/span.hpp +++ b/cpp/include/raft/core/detail/span.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -37,8 +37,7 @@ template struct extent_value_t : public std::integral_constant< std::size_t, - Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> { -}; + Count != dynamic_extent ? Count : (Extent != dynamic_extent ? Extent - Offset : Extent)> {}; /*! * If N is dynamic_extent, the extent of the returned span E is also @@ -47,31 +46,25 @@ struct extent_value_t template struct extent_as_bytes_value_t : public std::integral_constant { -}; + Extent == dynamic_extent ? 
Extent : sizeof(T) * Extent> {}; template struct is_allowed_extent_conversion_t : public std::integral_constant { -}; + From == To || From == dynamic_extent || To == dynamic_extent> {}; template struct is_allowed_element_type_conversion_t - : public std::integral_constant::value> { -}; + : public std::integral_constant::value> {}; template -struct is_span_oracle_t : std::false_type { -}; +struct is_span_oracle_t : std::false_type {}; template -struct is_span_oracle_t> : std::true_type { -}; +struct is_span_oracle_t> : std::true_type {}; template -struct is_span_t : public is_span_oracle_t::type> { -}; +struct is_span_t : public is_span_oracle_t::type> {}; template _RAFT_HOST_DEVICE constexpr auto lexicographical_compare(InputIt1 first1, diff --git a/cpp/include/raft/core/device_coo_matrix.hpp b/cpp/include/raft/core/device_coo_matrix.hpp index b1e9ca30fc..ce016dd5e0 100644 --- a/cpp/include/raft/core/device_coo_matrix.hpp +++ b/cpp/include/raft/core/device_coo_matrix.hpp @@ -79,8 +79,7 @@ template using device_coordinate_structure_view = coordinate_structure_view; template -struct is_device_coo_matrix : std::false_type { -}; +struct is_device_coo_matrix : std::false_type {}; template struct is_device_coo_matrix< device_coo_matrix> - : std::true_type { -}; + : std::true_type {}; template constexpr bool is_device_coo_matrix_v = is_device_coo_matrix::value; @@ -174,16 +172,15 @@ auto make_device_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] handle raft handle for managing expensive device resources - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_device_coo_matrix(raft::resources const& handle, - device_coordinate_structure_view structure_) + device_coordinate_structure_view structure) { - return device_sparsity_preserving_coo_matrix( - handle, - std::make_shared>(structure_)); + return device_sparsity_preserving_coo_matrix(handle, + structure); } /** @@ -212,16 +209,15 @@ auto make_device_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz) - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_device_coo_matrix_view( - ElementType* ptr, device_coordinate_structure_view structure_) + ElementType* ptr, device_coordinate_structure_view structure) { return device_coo_matrix_view( - raft::device_span(ptr, structure_.get_nnz()), - std::make_shared>(structure_)); + raft::device_span(ptr, structure.get_nnz()), structure); } /** @@ -251,19 +247,17 @@ auto make_device_coo_matrix_view( * @tparam ColType * @tparam NZType * @param[in] elements a device span containing nonzero matrix elements (size nnz) - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return */ template auto make_device_coo_matrix_view( raft::device_span elements, - device_coordinate_structure_view structure_) + device_coordinate_structure_view structure) { - RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz 
from the structure"); - return device_coo_matrix_view( - elements, - std::make_shared>(structure_)); + return device_coo_matrix_view(elements, structure); } /** @@ -338,7 +332,7 @@ auto make_device_coordinate_structure(raft::resources const& handle, * @return a sparsity-preserving coordinate structural view */ template -auto make_device_coo_structure_view( +auto make_device_coordinate_structure_view( RowType* rows, ColType* cols, RowType n_rows, ColType n_cols, NZType nnz) { return device_coordinate_structure_view( @@ -376,10 +370,10 @@ auto make_device_coo_structure_view( * @return a sparsity-preserving coordinate structural view */ template -auto make_device_coo_structure_view(raft::device_span rows, - raft::device_span cols, - RowType n_rows, - ColType n_cols) +auto make_device_coordinate_structure_view(raft::device_span rows, + raft::device_span cols, + RowType n_rows, + ColType n_cols) { return device_coordinate_structure_view(rows, cols, n_rows, n_cols); } diff --git a/cpp/include/raft/core/device_csr_matrix.hpp b/cpp/include/raft/core/device_csr_matrix.hpp index 59cabacf6d..869034e925 100644 --- a/cpp/include/raft/core/device_csr_matrix.hpp +++ b/cpp/include/raft/core/device_csr_matrix.hpp @@ -46,8 +46,7 @@ using device_sparsity_owning_csr_matrix = csr_matrix; template -struct is_device_csr_matrix : std::false_type { -}; +struct is_device_csr_matrix : std::false_type {}; template struct is_device_csr_matrix< device_csr_matrix> - : std::true_type { -}; + : std::true_type {}; template constexpr bool is_device_csr_matrix_v = is_device_csr_matrix::value; @@ -189,7 +187,7 @@ auto make_device_csr_matrix(raft::device_resources const& handle, * @tparam IndicesType * @tparam NZType * @param[in] handle raft handle for managing expensive device resources - * @param[in] structure_ a sparsity-preserving compressed structural view + * @param[in] structure a sparsity-preserving compressed structural view * @return a sparsity-preserving sparse matrix in compressed (csr) format */ template auto make_device_csr_matrix( raft::device_resources const& handle, - device_compressed_structure_view structure_) + device_compressed_structure_view structure) { return device_sparsity_preserving_csr_matrix( - handle, - std::make_shared>( - structure_)); + handle, structure); } /** @@ -232,7 +228,7 @@ auto make_device_csr_matrix( * @tparam IndicesType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on device (size nnz) - * @param[in] structure_ a sparsity-preserving compressed sparse structural view + * @param[in] structure a sparsity-preserving compressed sparse structural view * @return a sparsity-preserving csr matrix view */ template auto make_device_csr_matrix_view( - ElementType* ptr, device_compressed_structure_view structure_) + ElementType* ptr, device_compressed_structure_view structure) { return device_csr_matrix_view( - raft::device_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); + raft::device_span(ptr, structure.get_nnz()), structure); } /** @@ -273,7 +269,7 @@ auto make_device_csr_matrix_view( * @tparam IndicesType * @tparam NZType * @param[in] elements device span containing array of matrix elements (size nnz) - * @param[in] structure_ a sparsity-preserving structural view + * @param[in] structure a sparsity-preserving structural view * @return a sparsity-preserving csr matrix view */ template auto make_device_csr_matrix_view( raft::device_span elements, - device_compressed_structure_view structure_) + device_compressed_structure_view structure) { - 
RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz from the structure"); - return device_csr_matrix_view( - elements, std::make_shared(structure_)); + return device_csr_matrix_view(elements, structure); } /** @@ -365,7 +360,7 @@ auto make_device_compressed_structure(raft::device_resources const& handle, * @return a sparsity-preserving compressed structural view */ template -auto make_device_csr_structure_view( +auto make_device_compressed_structure_view( IndptrType* indptr, IndicesType* indices, IndptrType n_rows, IndicesType n_cols, NZType nnz) { return device_compressed_structure_view( @@ -408,9 +403,9 @@ auto make_device_csr_structure_view( * */ template -auto make_device_csr_structure_view(raft::device_span indptr, - raft::device_span indices, - IndicesType n_cols) +auto make_device_compressed_structure_view(raft::device_span indptr, + raft::device_span indices, + IndicesType n_cols) { return device_compressed_structure_view(indptr, indices, n_cols); } diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index f72ae36d64..c1898a3f09 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -45,11 +45,9 @@ template >; template -struct is_device_mdspan : std::false_type { -}; +struct is_device_mdspan : std::false_type {}; template -struct is_device_mdspan : std::bool_constant { -}; +struct is_device_mdspan : std::bool_constant {}; /** * @\brief Boolean to determine if template type T is either raft::device_mdspan or a derived type @@ -64,11 +62,9 @@ template using is_output_device_mdspan_t = is_device_mdspan>; template -struct is_managed_mdspan : std::false_type { -}; +struct is_managed_mdspan : std::false_type {}; template -struct is_managed_mdspan : std::bool_constant { -}; +struct is_managed_mdspan : std::bool_constant {}; /** * @\brief Boolean to determine if template type T is either raft::managed_mdspan or a derived type @@ -259,6 +255,36 @@ auto make_device_matrix_view(ElementType* ptr, IndexType n_rows, IndexType n_col return device_matrix_view{ptr, extents}; } +/** + * @brief Create a 2-dim mdspan instance for device pointer with a strided layout + * that is restricted to stride 1 in the trailing dimension. It's + * expected that the given layout policy match the layout of the underlying + * pointer. + * @tparam ElementType the data type of the matrix elements + * @tparam IndexType the index type of the extents + * @tparam LayoutPolicy policy for strides and layout ordering + * @param[in] ptr on device to wrap + * @param[in] n_rows number of rows in pointer + * @param[in] n_cols number of columns in pointer + * @param[in] stride leading dimension / stride of data + */ +template +auto make_device_strided_matrix_view(ElementType* ptr, + IndexType n_rows, + IndexType n_cols, + IndexType stride) +{ + constexpr auto is_row_major = std::is_same_v; + IndexType stride0 = is_row_major ? (stride > 0 ? stride : n_cols) : 1; + IndexType stride1 = is_row_major ? 1 : (stride > 0 ? stride : n_rows); + + assert(is_row_major ? stride0 >= n_cols : stride1 >= n_rows); + matrix_extent extents{n_rows, n_cols}; + + auto layout = make_strided_layout(extents, std::array{stride0, stride1}); + return device_matrix_view{ptr, layout}; +} + /** * @brief Create a 1-dim mdspan instance for device pointer. 
* @tparam ElementType the data type of the vector elements diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp index df6b39a368..1cab36561a 100644 --- a/cpp/include/raft/core/device_resources.hpp +++ b/cpp/include/raft/core/device_resources.hpp @@ -69,7 +69,7 @@ class device_resources : public resources { } device_resources(const device_resources& handle) : resources{handle} {} - device_resources(device_resources&&) = delete; + device_resources(device_resources&&) = delete; device_resources& operator=(device_resources&&) = delete; /** @@ -246,7 +246,7 @@ class stream_syncer { handle_.sync_stream_pool(); } - stream_syncer(const stream_syncer& other) = delete; + stream_syncer(const stream_syncer& other) = delete; stream_syncer& operator=(const stream_syncer& other) = delete; private: diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 02efebec9e..2a6b5657e2 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -39,7 +39,7 @@ class handle_t : public raft::device_resources { handle_t(const handle_t& handle) : device_resources{handle} {} - handle_t(handle_t&&) = delete; + handle_t(handle_t&&) = delete; handle_t& operator=(handle_t&&) = delete; /** diff --git a/cpp/include/raft/core/host_coo_matrix.hpp b/cpp/include/raft/core/host_coo_matrix.hpp index 45ec278a7d..32e7a9e3c4 100644 --- a/cpp/include/raft/core/host_coo_matrix.hpp +++ b/cpp/include/raft/core/host_coo_matrix.hpp @@ -78,8 +78,7 @@ template using host_coordinate_structure_view = coordinate_structure_view; template -struct is_host_coo_matrix : std::false_type { -}; +struct is_host_coo_matrix : std::false_type {}; template struct is_host_coo_matrix< host_coo_matrix> - : std::true_type { -}; + : std::true_type {}; template constexpr bool is_host_coo_matrix_v = is_host_coo_matrix::value; @@ -173,15 +171,15 @@ auto make_host_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] handle raft handle for managing expensive resources - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_host_coo_matrix(raft::resources const& handle, - host_coordinate_structure_view structure_) + host_coordinate_structure_view structure) { - return host_sparsity_preserving_coo_matrix( - handle, std::make_shared>(structure_)); + return host_sparsity_preserving_coo_matrix(handle, + structure); } /** @@ -210,15 +208,15 @@ auto make_host_coo_matrix(raft::resources const& handle, * @tparam ColType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on host (size nnz) - * @param[in] structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return a sparsity-preserving sparse matrix in coordinate (coo) format */ template auto make_host_coo_matrix_view(ElementType* ptr, - host_coordinate_structure_view structure_) + host_coordinate_structure_view structure) { return host_coo_matrix_view( - raft::host_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); + raft::host_span(ptr, structure.get_nnz()), structure); } /** @@ -248,17 +246,16 @@ auto make_host_coo_matrix_view(ElementType* ptr, * @tparam ColType * @tparam NZType * @param[in] elements a host span containing nonzero matrix elements (size nnz) - * @param[in] 
structure_ a sparsity-preserving coordinate structural view + * @param[in] structure a sparsity-preserving coordinate structural view * @return */ template auto make_host_coo_matrix_view(raft::host_span elements, - host_coordinate_structure_view structure_) + host_coordinate_structure_view structure) { - RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz from the structure"); - return host_coo_matrix_view(elements, - std::make_shared(structure_)); + return host_coo_matrix_view(elements, structure); } /** @@ -333,7 +330,7 @@ auto make_host_coordinate_structure(raft::resources const& handle, * @return a sparsity-preserving coordinate structural view */ template -auto make_host_coo_structure_view( +auto make_host_coordinate_structure_view( RowType* rows, ColType* cols, RowType n_rows, ColType n_cols, NZType nnz) { return host_coordinate_structure_view( @@ -371,10 +368,10 @@ auto make_host_coo_structure_view( * @return a sparsity-preserving coordinate structural view */ template -auto make_host_coo_structure_view(raft::host_span rows, - raft::host_span cols, - RowType n_rows, - ColType n_cols) +auto make_host_coordinate_structure_view(raft::host_span rows, + raft::host_span cols, + RowType n_rows, + ColType n_cols) { return host_coordinate_structure_view(rows, cols, n_rows, n_cols); } diff --git a/cpp/include/raft/core/host_csr_matrix.hpp b/cpp/include/raft/core/host_csr_matrix.hpp index 437f60814e..86199335f2 100644 --- a/cpp/include/raft/core/host_csr_matrix.hpp +++ b/cpp/include/raft/core/host_csr_matrix.hpp @@ -45,8 +45,7 @@ using host_sparsity_owning_csr_matrix = csr_matrix; template -struct is_host_csr_matrix : std::false_type { -}; +struct is_host_csr_matrix : std::false_type {}; template struct is_host_csr_matrix< host_csr_matrix> - : std::true_type { -}; + : std::true_type {}; template constexpr bool is_host_csr_matrix_v = is_host_csr_matrix::value; @@ -189,20 +187,18 @@ auto make_host_csr_matrix(raft::resources const& handle, * @tparam IndicesType * @tparam NZType * @param[in] handle raft handle for managing expensive resources - * @param[in] structure_ a sparsity-preserving compressed structural view + * @param[in] structure a sparsity-preserving compressed structural view * @return a sparsity-preserving sparse matrix in compressed (csr) format */ template -auto make_host_csr_matrix( - raft::resources const& handle, - host_compressed_structure_view structure_) +auto make_host_csr_matrix(raft::resources const& handle, + host_compressed_structure_view structure) { return host_sparsity_preserving_csr_matrix( - handle, - std::make_shared>(structure_)); + handle, structure); } /** @@ -231,7 +227,7 @@ auto make_host_csr_matrix( * @tparam IndicesType * @tparam NZType * @param[in] ptr a pointer to array of nonzero matrix elements on host (size nnz) - * @param[in] structure_ a sparsity-preserving compressed sparse structural view + * @param[in] structure a sparsity-preserving compressed sparse structural view * @return a sparsity-preserving csr matrix view */ template auto make_host_csr_matrix_view( - ElementType* ptr, host_compressed_structure_view structure_) + ElementType* ptr, host_compressed_structure_view structure) { return host_csr_matrix_view( - raft::host_span(ptr, structure_.get_nnz()), std::make_shared(structure_)); + raft::host_span(ptr, structure.get_nnz()), structure); } /** @@ -272,7 +268,7 @@ auto make_host_csr_matrix_view( * @tparam IndicesType * @tparam NZType * @param[in] 
elements host span containing array of matrix elements (size nnz) - * @param[in] structure_ a sparsity-preserving structural view + * @param[in] structure a sparsity-preserving structural view * @return a sparsity-preserving csr matrix view */ template auto make_host_csr_matrix_view( raft::host_span elements, - host_compressed_structure_view structure_) + host_compressed_structure_view structure) { - RAFT_EXPECTS(elements.size() == structure_.get_nnz(), + RAFT_EXPECTS(elements.size() == structure.get_nnz(), "Size of elements must be equal to the nnz from the structure"); - return host_csr_matrix_view( - elements, std::make_shared(structure_)); + return host_csr_matrix_view(elements, structure); } /** @@ -365,7 +360,7 @@ auto make_host_compressed_structure(raft::resources const& handle, * @return a sparsity-preserving compressed structural view */ template -auto make_host_csr_structure_view( +auto make_host_compressed_structure_view( IndptrType* indptr, IndicesType* indices, IndptrType n_rows, IndicesType n_cols, NZType nnz) { return host_compressed_structure_view( @@ -408,9 +403,9 @@ auto make_host_csr_structure_view( * */ template -auto make_host_csr_structure_view(raft::host_span indptr, - raft::host_span indices, - IndicesType n_cols) +auto make_host_compressed_structure_view(raft::host_span indptr, + raft::host_span indices, + IndicesType n_cols) { return host_compressed_structure_view(indptr, indices, n_cols); } diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp index a6cdec7a84..9a675680ac 100644 --- a/cpp/include/raft/core/host_mdspan.hpp +++ b/cpp/include/raft/core/host_mdspan.hpp @@ -37,11 +37,9 @@ template >; template -struct is_host_mdspan : std::false_type { -}; +struct is_host_mdspan : std::false_type {}; template -struct is_host_mdspan : std::bool_constant { -}; +struct is_host_mdspan : std::bool_constant {}; /** * @\brief Boolean to determine if template type T is either raft::host_mdspan or a derived type diff --git a/cpp/include/raft/core/interruptible.hpp b/cpp/include/raft/core/interruptible.hpp index 0cc4af2bbf..62e481a801 100644 --- a/cpp/include/raft/core/interruptible.hpp +++ b/cpp/include/raft/core/interruptible.hpp @@ -172,10 +172,10 @@ class interruptible { inline void cancel() noexcept { continue_.clear(std::memory_order_relaxed); } // don't allow the token to leave the shared_ptr - interruptible(interruptible const&) = delete; - interruptible(interruptible&&) = delete; + interruptible(interruptible const&) = delete; + interruptible(interruptible&&) = delete; auto operator=(interruptible const&) -> interruptible& = delete; - auto operator=(interruptible&&) -> interruptible& = delete; + auto operator=(interruptible&&) -> interruptible& = delete; private: /** Global registry of thread-local cancellation stores. 
*/ diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index 192d160d45..2e0d1117a1 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -32,8 +32,8 @@ struct KeyValuePair { typedef _Key Key; ///< Key data type typedef _Value Value; ///< Value data type - Key key; ///< Item key - Value value; ///< Item value + Key key; ///< Item key + Value value; ///< Item value /// Constructor RAFT_INLINE_FUNCTION KeyValuePair() {} diff --git a/cpp/include/raft/core/logger-ext.hpp b/cpp/include/raft/core/logger-ext.hpp new file mode 100644 index 0000000000..8fd29cf1d6 --- /dev/null +++ b/cpp/include/raft/core/logger-ext.hpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // std::unique_ptr +#include // RAFT_INLINE_CONDITIONAL +#include // std::string +#include // std::unordered_map + +namespace raft { + +static const std::string RAFT_NAME = "raft"; +static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); + +namespace detail { +RAFT_INLINE_CONDITIONAL std::string format(const char* fmt, ...); +} +/** + * @brief The main Logging class for raft library. + * + * This class acts as a thin wrapper over the underlying `spdlog` interface. The + * design is done in this way in order to avoid us having to also ship `spdlog` + * header files in our installation. + * + * @todo This currently only supports logging to stdout. Need to add support in + * future to add custom loggers as well [Issue #2046] + */ +class logger { + public: + // @todo setting the logger once per process with + logger(std::string const& name_ = ""); + /** + * @brief Singleton method to get the underlying logger object + * + * @return the singleton logger object + */ + static logger& get(std::string const& name = ""); + + /** + * @brief Set the logging level. + * + * Only messages with level equal or above this will be printed + * + * @param[in] level logging level + * + * @note The log level will actually be set only if the input is within the + * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll + * be ignored. See documentation of decisiontree for how this gets used + */ + void set_level(int level); + + /** + * @brief Set the logging pattern + * + * @param[in] pattern the pattern to be set. 
Refer this link + * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting + * to know the right syntax of this pattern + */ + void set_pattern(const std::string& pattern); + + /** + * @brief Register a callback function to be run in place of usual log call + * + * @param[in] callback the function to be run on all logged messages + */ + void set_callback(void (*callback)(int lvl, const char* msg)); + + /** + * @brief Register a flush function compatible with the registered callback + * + * @param[in] flush the function to use when flushing logs + */ + void set_flush(void (*flush)()); + + /** + * @brief Tells whether messages will be logged for the given log level + * + * @param[in] level log level to be checked for + * @return true if messages will be logged for this level, else false + */ + bool should_log_for(int level) const; + /** + * @brief Query for the current log level + * + * @return the current log level + */ + int get_level() const; + + /** + * @brief Get the current logging pattern + * @return the pattern + */ + std::string get_pattern() const; + + /** + * @brief Main logging method + * + * @param[in] level logging level of this message + * @param[in] fmt C-like format string, followed by respective params + */ + void log(int level, const char* fmt, ...); + + /** + * @brief Flush logs by calling flush on underlying logger + */ + void flush(); + + ~logger(); + + private: + logger(); + // pimpl pattern: + // https://learn.microsoft.com/en-us/cpp/cpp/pimpl-for-compile-time-encapsulation-modern-cpp?view=msvc-170 + class impl; + std::unique_ptr pimpl; + static inline std::unordered_map> log_map; +}; // class logger + +}; // namespace raft diff --git a/cpp/include/raft/core/logger-inl.hpp b/cpp/include/raft/core/logger-inl.hpp new file mode 100644 index 0000000000..fcfa1f1333 --- /dev/null +++ b/cpp/include/raft/core/logger-inl.hpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include "logger-macros.hpp" +// The logger-ext.hpp file contains the class declaration of the logger class. +// In this case, it is okay to include the logger-ext.hpp file because it +// contains no RAFT_EXPLICIT template instantiations. +#include "logger-ext.hpp" + +#define SPDLOG_HEADER_ONLY +#include +#include // RAFT_INLINE_CONDITIONAL +#include // NOLINT +#include // NOLINT + +namespace raft { + +namespace detail { + +inline std::string format(const char* fmt, va_list& vl) +{ + va_list vl_copy; + va_copy(vl_copy, vl); + int length = std::vsnprintf(nullptr, 0, fmt, vl_copy); + assert(length >= 0); + std::vector buf(length + 1); + std::vsnprintf(buf.data(), length + 1, fmt, vl); + return std::string(buf.data()); +} + +RAFT_INLINE_CONDITIONAL std::string format(const char* fmt, ...) 
+{ + va_list vl; + va_start(vl, fmt); + std::string str = format(fmt, vl); + va_end(vl); + return str; +} + +inline int convert_level_to_spdlog(int level) +{ + level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); + return RAFT_LEVEL_TRACE - level; +} + +} // namespace detail + +class logger::impl { // defined privately here + // ... all private data and functions: all of these + // can now change without recompiling callers ... + public: + std::shared_ptr sink; + std::shared_ptr spdlogger; + std::string cur_pattern; + int cur_level; + + impl(std::string const& name_ = "") + : sink{std::make_shared()}, + spdlogger{std::make_shared(name_, sink)}, + cur_pattern() + { + } +}; // class logger::impl + +RAFT_INLINE_CONDITIONAL logger::logger(std::string const& name_) : pimpl(new impl(name_)) +{ + set_pattern(default_log_pattern); + set_level(RAFT_ACTIVE_LEVEL); +} + +RAFT_INLINE_CONDITIONAL logger& logger::get(std::string const& name) +{ + if (log_map.find(name) == log_map.end()) { log_map[name] = std::make_shared(name); } + return *log_map[name]; +} + +RAFT_INLINE_CONDITIONAL void logger::set_level(int level) +{ + level = raft::detail::convert_level_to_spdlog(level); + pimpl->spdlogger->set_level(static_cast(level)); +} + +RAFT_INLINE_CONDITIONAL void logger::set_pattern(const std::string& pattern) +{ + pimpl->cur_pattern = pattern; + pimpl->spdlogger->set_pattern(pattern); +} + +RAFT_INLINE_CONDITIONAL void logger::set_callback(void (*callback)(int lvl, const char* msg)) +{ + pimpl->sink->set_callback(callback); +} + +RAFT_INLINE_CONDITIONAL void logger::set_flush(void (*flush)()) { pimpl->sink->set_flush(flush); } + +RAFT_INLINE_CONDITIONAL bool logger::should_log_for(int level) const +{ + level = raft::detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + return pimpl->spdlogger->should_log(level_e); +} + +RAFT_INLINE_CONDITIONAL int logger::get_level() const +{ + auto level_e = pimpl->spdlogger->level(); + return RAFT_LEVEL_TRACE - static_cast(level_e); +} + +RAFT_INLINE_CONDITIONAL std::string logger::get_pattern() const { return pimpl->cur_pattern; } + +RAFT_INLINE_CONDITIONAL void logger::log(int level, const char* fmt, ...) +{ + level = raft::detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + // explicit check to make sure that we only expand messages when required + if (pimpl->spdlogger->should_log(level_e)) { + va_list vl; + va_start(vl, fmt); + auto msg = raft::detail::format(fmt, vl); + va_end(vl); + pimpl->spdlogger->log(level_e, msg); + } +} + +RAFT_INLINE_CONDITIONAL void logger::flush() { pimpl->spdlogger->flush(); } + +RAFT_INLINE_CONDITIONAL logger::~logger() {} + +}; // namespace raft diff --git a/cpp/include/raft/core/logger-macros.hpp b/cpp/include/raft/core/logger-macros.hpp new file mode 100644 index 0000000000..5ddb072067 --- /dev/null +++ b/cpp/include/raft/core/logger-macros.hpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/** + * @defgroup logging levels used in raft + * + * @note exactly match the corresponding ones (but reverse in terms of value) + * in spdlog for wrapping purposes + * + * @{ + */ +#define RAFT_LEVEL_TRACE 6 +#define RAFT_LEVEL_DEBUG 5 +#define RAFT_LEVEL_INFO 4 +#define RAFT_LEVEL_WARN 3 +#define RAFT_LEVEL_ERROR 2 +#define RAFT_LEVEL_CRITICAL 1 +#define RAFT_LEVEL_OFF 0 +/** @} */ + +#if !defined(RAFT_ACTIVE_LEVEL) +#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_INFO +#endif + +/** + * @defgroup loggerMacros Helper macros for dealing with logging + * @{ + */ +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE_VEC(ptr, len) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + print_vector(#ptr, ptr, len, ss); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE_VEC(ptr, len) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) +#define RAFT_LOG_DEBUG(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_DEBUG(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) +#define RAFT_LOG_INFO(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_INFO(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) +#define RAFT_LOG_WARN(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_WARN(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) +#define RAFT_LOG_ERROR(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_ERROR(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) +#define RAFT_LOG_CRITICAL(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_CRITICAL(fmt, ...) void(0) +#endif +/** @} */ diff --git a/cpp/include/raft/core/logger.hpp b/cpp/include/raft/core/logger.hpp index 3984ec042a..59968ff5e5 100644 --- a/cpp/include/raft/core/logger.hpp +++ b/cpp/include/raft/core/logger.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
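// ---------------------------------------------------------------------------
// Editor's example (not part of the diff): a minimal usage sketch for the
// RAFT_LOG_* macros declared in logger-macros.hpp above. It assumes only what
// the new headers show: the RAFT_LEVEL_* constants, raft::logger::get(), and
// the fact that RAFT_LOG_DEBUG / RAFT_LOG_TRACE compile to `void(0)` unless
// the translation unit defines RAFT_ACTIVE_LEVEL at or above that level.
#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_DEBUG  // must precede the logger include
#include <raft/core/logger.hpp>

using raft::RAFT_NAME;  // the macros reference RAFT_NAME unqualified

void logging_example()
{
  // the runtime level must also allow the message through
  raft::logger::get(RAFT_NAME).set_level(RAFT_LEVEL_DEBUG);
  RAFT_LOG_INFO("processing %d rows", 1024);  // printf-style formatting
  RAFT_LOG_DEBUG("rebuilt in %f ms", 12.5);   // compiled in because of RAFT_ACTIVE_LEVEL
}
// ---------------------------------------------------------------------------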
@@ -15,310 +15,10 @@ */ #pragma once -#ifndef __RAFT_RT_LOGGER -#define __RAFT_RT_LOGGER +#include "logger-macros.hpp" -#include +#include "logger-ext.hpp" -#include - -#include -#include -#include -#include -#include - -#include - -#define SPDLOG_HEADER_ONLY -#include -#include -#include // NOLINT -#include // NOLINT - -/** - * @defgroup logging levels used in raft - * - * @note exactly match the corresponding ones (but reverse in terms of value) - * in spdlog for wrapping purposes - * - * @{ - */ -#define RAFT_LEVEL_TRACE 6 -#define RAFT_LEVEL_DEBUG 5 -#define RAFT_LEVEL_INFO 4 -#define RAFT_LEVEL_WARN 3 -#define RAFT_LEVEL_ERROR 2 -#define RAFT_LEVEL_CRITICAL 1 -#define RAFT_LEVEL_OFF 0 -/** @} */ - -#if !defined(RAFT_ACTIVE_LEVEL) -#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_INFO +#if !defined(RAFT_COMPILED) +#include "logger-inl.hpp" #endif - -namespace raft { - -static const std::string RAFT_NAME = "raft"; -static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); - -namespace detail { - -/** - * @defgroup CStringFormat Expand a C-style format string - * - * @brief Expands C-style formatted string into std::string - * - * @param[in] fmt format string - * @param[in] vl respective values for each of format modifiers in the string - * - * @return the expanded `std::string` - * - * @{ - */ -inline std::string format(const char* fmt, va_list& vl) -{ - va_list vl_copy; - va_copy(vl_copy, vl); - int length = std::vsnprintf(nullptr, 0, fmt, vl_copy); - assert(length >= 0); - std::vector buf(length + 1); - std::vsnprintf(buf.data(), length + 1, fmt, vl); - return std::string(buf.data()); -} - -inline std::string format(const char* fmt, ...) -{ - va_list vl; - va_start(vl, fmt); - std::string str = format(fmt, vl); - va_end(vl); - return str; -} -/** @} */ - -inline int convert_level_to_spdlog(int level) -{ - level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); - return RAFT_LEVEL_TRACE - level; -} - -} // namespace detail - -/** - * @brief The main Logging class for raft library. - * - * This class acts as a thin wrapper over the underlying `spdlog` interface. The - * design is done in this way in order to avoid us having to also ship `spdlog` - * header files in our installation. - * - * @todo This currently only supports logging to stdout. Need to add support in - * future to add custom loggers as well [Issue #2046] - */ -class logger { - public: - // @todo setting the logger once per process with - logger(std::string const& name_ = "") - : sink{std::make_shared()}, - spdlogger{std::make_shared(name_, sink)}, - cur_pattern() - { - set_pattern(default_log_pattern); - set_level(RAFT_ACTIVE_LEVEL); - } - /** - * @brief Singleton method to get the underlying logger object - * - * @return the singleton logger object - */ - static logger& get(std::string const& name = "") - { - if (log_map.find(name) == log_map.end()) { - log_map[name] = std::make_shared(name); - } - return *log_map[name]; - } - - /** - * @brief Set the logging level. - * - * Only messages with level equal or above this will be printed - * - * @param[in] level logging level - * - * @note The log level will actually be set only if the input is within the - * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll - * be ignored. 
See documentation of decisiontree for how this gets used - */ - void set_level(int level) - { - level = raft::detail::convert_level_to_spdlog(level); - spdlogger->set_level(static_cast(level)); - } - - /** - * @brief Set the logging pattern - * - * @param[in] pattern the pattern to be set. Refer this link - * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting - * to know the right syntax of this pattern - */ - void set_pattern(const std::string& pattern) - { - cur_pattern = pattern; - spdlogger->set_pattern(pattern); - } - - /** - * @brief Register a callback function to be run in place of usual log call - * - * @param[in] callback the function to be run on all logged messages - */ - void set_callback(void (*callback)(int lvl, const char* msg)) { sink->set_callback(callback); } - - /** - * @brief Register a flush function compatible with the registered callback - * - * @param[in] flush the function to use when flushing logs - */ - void set_flush(void (*flush)()) { sink->set_flush(flush); } - - /** - * @brief Tells whether messages will be logged for the given log level - * - * @param[in] level log level to be checked for - * @return true if messages will be logged for this level, else false - */ - bool should_log_for(int level) const - { - level = raft::detail::convert_level_to_spdlog(level); - auto level_e = static_cast(level); - return spdlogger->should_log(level_e); - } - - /** - * @brief Query for the current log level - * - * @return the current log level - */ - int get_level() const - { - auto level_e = spdlogger->level(); - return RAFT_LEVEL_TRACE - static_cast(level_e); - } - - /** - * @brief Get the current logging pattern - * @return the pattern - */ - std::string get_pattern() const { return cur_pattern; } - - /** - * @brief Main logging method - * - * @param[in] level logging level of this message - * @param[in] fmt C-like format string, followed by respective params - */ - void log(int level, const char* fmt, ...) - { - level = raft::detail::convert_level_to_spdlog(level); - auto level_e = static_cast(level); - // explicit check to make sure that we only expand messages when required - if (spdlogger->should_log(level_e)) { - va_list vl; - va_start(vl, fmt); - auto msg = raft::detail::format(fmt, vl); - va_end(vl); - spdlogger->log(level_e, msg); - } - } - - /** - * @brief Flush logs by calling flush on underlying logger - */ - void flush() { spdlogger->flush(); } - - ~logger() {} - - private: - logger(); - - static inline std::unordered_map> log_map; - std::shared_ptr sink; - std::shared_ptr spdlogger; - std::string cur_pattern; - int cur_level; -}; // class logger - -}; // namespace raft - -/** - * @defgroup loggerMacros Helper macros for dealing with logging - * @{ - */ -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE(fmt, ...) 
void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE_VEC(ptr, len) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - print_vector(#ptr, ptr, len, ss); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE_VEC(ptr, len) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) -#define RAFT_LOG_DEBUG(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_DEBUG(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) -#define RAFT_LOG_INFO(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_INFO(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) -#define RAFT_LOG_WARN(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_WARN(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) -#define RAFT_LOG_ERROR(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_ERROR(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) -#define RAFT_LOG_CRITICAL(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_CRITICAL(fmt, ...) void(0) -#endif -/** @} */ - -#endif \ No newline at end of file diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 61c1b500e6..c7350a978c 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -25,11 +25,11 @@ #include #include +#include #include #include #include #include -#include namespace raft { /** @@ -45,23 +45,21 @@ namespace raft { template class array_interface { /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() noexcept { return static_cast(this)->view(); } /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() const noexcept { return static_cast(this)->view(); } }; namespace detail { template -struct is_array_interface : std::false_type { -}; +struct is_array_interface : std::false_type {}; template struct is_array_interface().view())>> - : std::bool_constant().view())>> { -}; + : std::bool_constant().view())>> {}; template using is_array_interface_t = is_array_interface>; @@ -76,16 +74,13 @@ inline constexpr bool is_array_interface_v = is_array_interface -struct is_array_interface : std::true_type { -}; +struct is_array_interface : std::true_type {}; template -struct is_array_interface : detail::is_array_interface_t { -}; +struct is_array_interface : detail::is_array_interface_t {}; template struct is_array_interface : std::conditional_t, is_array_interface, - std::false_type> { -}; + std::false_type> {}; /** * @\brief Boolean to determine if variadic template types Tn are raft::array_interface * or derived type or any type that has a member function `view()` that returns either @@ -108,7 +103,8 @@ inline constexpr bool is_array_interface_v = is_array_interface::value; * template. * * - Most of the constructors from the reference implementation is removed to make sure - * CUDA stream is honorred. 
+ * CUDA stream is honored. Note that this class is not coupled to CUDA and therefore + * will only be used in the case where the device variant is used. * * - unique_size is not implemented, which is still working in progress in the proposal * @@ -177,9 +173,9 @@ class mdarray constexpr mdarray(mdarray&&) noexcept(std::is_nothrow_move_constructible::value) = default; - constexpr auto operator =(mdarray const&) noexcept( + constexpr auto operator=(mdarray const&) noexcept( std::is_nothrow_copy_assignable::value) -> mdarray& = default; - constexpr auto operator =(mdarray&&) noexcept( + constexpr auto operator=(mdarray&&) noexcept( std::is_nothrow_move_assignable::value) -> mdarray& = default; ~mdarray() noexcept(std::is_nothrow_destructible::value) = default; @@ -201,8 +197,9 @@ class mdarray #endif // RAFT_MDARRAY_CTOR_CONSTEXPR /** - * @brief The only constructor that can create storage, this is to make sure CUDA stream is being - * used. + * @brief The only constructor that can create storage, raft::resources is accepted + * so that the device implementation can make sure the relevant CUDA stream is + * being used for allocation. */ RAFT_MDARRAY_CTOR_CONSTEXPR mdarray(raft::resources const& handle, mapping_type const& m, @@ -220,11 +217,11 @@ class mdarray #undef RAFT_MDARRAY_CTOR_CONSTEXPR /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() noexcept { return view_type(c_.data(), map_, cp_.make_accessor_policy()); } /** - * @brief Get a mdspan that can be passed down to CUDA kernels. + * @brief Get an mdspan */ auto view() const noexcept { diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp index 1c69cdd973..cd9ca26ed9 100644 --- a/cpp/include/raft/core/mdspan.hpp +++ b/cpp/include/raft/core/mdspan.hpp @@ -85,28 +85,22 @@ template *); template -struct is_mdspan : std::false_type { -}; +struct is_mdspan : std::false_type {}; template struct is_mdspan()))>> - : std::true_type { -}; + : std::true_type {}; template -struct is_input_mdspan : std::false_type { -}; +struct is_input_mdspan : std::false_type {}; template struct is_input_mdspan()))>> - : std::bool_constant> { -}; + : std::bool_constant> {}; template -struct is_output_mdspan : std::false_type { -}; +struct is_output_mdspan : std::false_type {}; template struct is_output_mdspan()))>> - : std::bool_constant> { -}; + : std::bool_constant> {}; template using is_mdspan_t = is_mdspan>; diff --git a/cpp/include/raft/core/nvtx.hpp b/cpp/include/raft/core/nvtx.hpp index 09a41f10a6..57338c32c7 100644 --- a/cpp/include/raft/core/nvtx.hpp +++ b/cpp/include/raft/core/nvtx.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -144,9 +144,9 @@ class range { ~range() { pop_range(); } /* This object is not meant to be touched. 
*/ - range(const range&) = delete; - range(range&&) = delete; - auto operator=(const range&) -> range& = delete; + range(const range&) = delete; + range(range&&) = delete; + auto operator=(const range&) -> range& = delete; auto operator=(range&&) -> range& = delete; static auto operator new(std::size_t) -> void* = delete; static auto operator new[](std::size_t) -> void* = delete; diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 35ae3d715f..ebc41e0f8e 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -18,6 +18,7 @@ #include #include #include +#include namespace raft::resource { class device_memory_resource : public resource { @@ -72,4 +73,4 @@ inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_ { res.add_resource_factory(std::make_shared(mr)); }; -} // namespace raft::resource \ No newline at end of file +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp index cf302e25f9..2dc4eb1f9d 100644 --- a/cpp/include/raft/core/resource/resource_types.hpp +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -42,7 +42,7 @@ enum resource_type { THRUST_POLICY, // thrust execution policy WORKSPACE_RESOURCE, // rmm device memory resource - LAST_KEY // reserved for the last key + LAST_KEY // reserved for the last key }; /** @@ -83,6 +83,8 @@ class resource_factory { * @return resource instance */ virtual resource* make_resource() = 0; + + virtual ~resource_factory() {} }; /** diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp index 64e281e934..e0f51b61b4 100644 --- a/cpp/include/raft/core/resources.hpp +++ b/cpp/include/raft/core/resources.hpp @@ -18,6 +18,7 @@ #include "resource/resource_types.hpp" #include #include +#include // RAFT_EXPECTS #include #include #include @@ -67,7 +68,7 @@ class resources { * Note that this does not create any new resources. */ resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {} - resources(resources&&) = delete; + resources(resources&&) = delete; resources& operator=(resources&&) = delete; /** @@ -128,4 +129,4 @@ class resources { mutable std::vector factories_; mutable std::vector resources_; }; -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/core/span.hpp b/cpp/include/raft/core/span.hpp index 188d58c896..a896ba1977 100644 --- a/cpp/include/raft/core/span.hpp +++ b/cpp/include/raft/core/span.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
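// ---------------------------------------------------------------------------
// Editor's note on the `virtual ~resource_factory() {}` added above: factories
// are held and destroyed through base-class pointers, and deleting a derived
// object via a base pointer without a virtual destructor is undefined behavior
// (the derived destructor never runs). A standalone sketch with hypothetical
// names -- `counting_factory` is not a raft type:
#include <memory>

struct factory_base {
  virtual ~factory_base() {}         // without `virtual`, the delete below is UB
  virtual int make_resource() = 0;
};

struct counting_factory : factory_base {
  int* count_ = new int(0);
  ~counting_factory() override { delete count_; }  // reached only via virtual dispatch
  int make_resource() override { return ++(*count_); }
};

int main()
{
  std::unique_ptr<factory_base> f = std::make_unique<counting_factory>();
  return f->make_resource();  // f is destroyed through factory_base*: needs the virtual dtor
}
// ---------------------------------------------------------------------------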
@@ -104,7 +104,7 @@ class span { constexpr span(span&& other) noexcept = default; constexpr auto operator=(span const& other) noexcept -> span& = default; - constexpr auto operator=(span&& other) noexcept -> span& = default; + constexpr auto operator=(span&& other) noexcept -> span& = default; constexpr auto begin() const noexcept -> iterator { return data(); } diff --git a/cpp/include/raft/core/sparse_types.hpp b/cpp/include/raft/core/sparse_types.hpp index 207cc944d2..a14944ed5b 100644 --- a/cpp/include/raft/core/sparse_types.hpp +++ b/cpp/include/raft/core/sparse_types.hpp @@ -109,7 +109,7 @@ class sparse_matrix_view { * Return a view of the structure underlying this matrix * @return */ - structure_view_type get_structure() { return structure_view_; } + structure_view_type structure_view() { return structure_view_; } /** * Return a span of the nonzero elements of the matrix @@ -158,18 +158,19 @@ class sparse_matrix { using container_policy_type = ContainerPolicy; using container_type = typename container_policy_type::container_type; + // constructor that owns the data and the structure sparse_matrix(raft::resources const& handle, row_type n_rows, col_type n_cols, nnz_type nnz = 0) noexcept(std::is_nothrow_default_constructible_v) - : structure_{std::make_shared(handle, n_rows, n_cols, nnz)}, - cp_{}, - c_elements_{cp_.create(handle, 0)} {}; + : structure_{handle, n_rows, n_cols, nnz}, cp_{}, c_elements_{cp_.create(handle, 0)} {}; // Constructor that owns the data but not the structure - sparse_matrix(raft::resources const& handle, std::shared_ptr structure) noexcept( + // This constructor is only callable with a `structure_type == *_structure_view` + // which makes it okay to copy + sparse_matrix(raft::resources const& handle, structure_type structure) noexcept( std::is_nothrow_default_constructible_v) - : structure_{structure}, cp_{}, c_elements_{cp_.create(handle, structure.get()->get_nnz())} {}; + : structure_{structure}, cp_{}, c_elements_{cp_.create(handle, structure_.get_nnz())} {}; constexpr sparse_matrix(sparse_matrix const&) noexcept( std::is_nothrow_copy_constructible_v) = default; @@ -187,7 +188,7 @@ class sparse_matrix { raft::span get_elements() { - return raft::span(c_elements_.data(), structure_view().get_nnz()); + return raft::span(c_elements_.data(), structure_.get_nnz()); } /** @@ -209,7 +210,7 @@ class sparse_matrix { } protected: - std::shared_ptr structure_; + structure_type structure_; container_policy_type cp_; container_type c_elements_; }; diff --git a/cpp/include/raft/core/temporary_device_buffer.hpp b/cpp/include/raft/core/temporary_device_buffer.hpp index 194471c5de..4baa7e9597 100644 --- a/cpp/include/raft/core/temporary_device_buffer.hpp +++ b/cpp/include/raft/core/temporary_device_buffer.hpp @@ -55,10 +55,10 @@ class temporary_device_buffer { static constexpr bool is_const_pointer_ = std::is_const_v; public: - temporary_device_buffer(temporary_device_buffer const&) = delete; + temporary_device_buffer(temporary_device_buffer const&) = delete; temporary_device_buffer& operator=(temporary_device_buffer const&) = delete; - constexpr temporary_device_buffer(temporary_device_buffer&&) = default; + constexpr temporary_device_buffer(temporary_device_buffer&&) = default; constexpr temporary_device_buffer& operator=(temporary_device_buffer&&) = default; /** diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh index 7a4fe0ce83..68e843c6f5 100644 --- 
a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh @@ -30,13 +30,11 @@ namespace raft::distance::detail::ops { // This pattern is described in: // https://en.cppreference.com/w/cpp/types/void_t template -struct has_cutlass_op : std::false_type { -}; +struct has_cutlass_op : std::false_type {}; // Specialization recognizes types that do support CUTLASS template struct has_cutlass_op().get_cutlass_op())>> - : std::true_type { -}; + : std::true_type {}; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index aaf3052892..2154aa560c 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -16,13 +16,26 @@ #pragma once +#include +#include #include +#include +// #include +#include +#include #include #include namespace raft::distance::kernels::detail { +template +using dense_input_matrix_view_t = raft::device_matrix_view; +template +using dense_output_matrix_view_t = raft::device_matrix_view; +template +using csr_input_matrix_view_t = raft::device_csr_matrix_view; + /** * Base class for general Gram matrices * A Gram matrix is the Hermitian matrix of inner probucts G_ik = @@ -37,14 +50,135 @@ namespace raft::distance::kernels::detail { */ template class GramMatrixBase { + protected: cublasHandle_t cublas_handle; + bool legacy_interface; public: - GramMatrixBase(cublasHandle_t cublas_handle) : cublas_handle(cublas_handle){}; + GramMatrixBase() : legacy_interface(false){}; + [[deprecated]] GramMatrixBase(cublasHandle_t cublas_handle) + : cublas_handle(cublas_handle), legacy_interface(true){}; virtual ~GramMatrixBase(){}; /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + /** Convenience function to evaluate the Gram matrix for two vector sets. + * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + /** Convenience function to evaluate the Gram matrix for two vector sets. 
+ * Vector sets are provided in Matrix format + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void operator()(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1 = nullptr, + math_t* norm_x2 = nullptr) + { + evaluate(handle, x1, x2, out, norm_x1, norm_x2); + } + + // unfortunately, 'evaluate' cannot be templatized as it needs to be virtual + + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + /** Evaluate the Gram matrix for two vector sets using simple dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + virtual void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + linear(handle, x1, x2, out); + } + + /** Evaluate the Gram matrix for two vector sets using simple dot product. 
* * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 @@ -55,29 +189,26 @@ class GramMatrixBase { * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) */ - virtual void operator()(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1 = 0, - int ld2 = 0, - int ld_out = 0) + [[deprecated]] virtual void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { - if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; } - if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } - if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; } - evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); } - /** Evaluate the Gram matrix for two vector sets using simple dot product. + /** Convenience function to evaluate the Gram matrix for two vector sets. * * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 @@ -88,30 +219,30 @@ class GramMatrixBase { * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) + * @param ld1 leading dimension of x1 + * @param ld2 leading dimension of x2 + * @param ld_out leading dimension of out */ - virtual void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void operator()(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1 = 0, + int ld2 = 0, + int ld_out = 0) { - linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); + ASSERT(legacy_interface, "Legacy interface can only be used with legacy ctor."); + if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; } + if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } + if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; } + evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); } - // private: - // The following methods should be private, they are kept public to avoid: - // "error: The enclosing parent function ("distance") for an extended - // __device__ lambda cannot have private or protected access within its class" - + protected: /** Calculates the Gram matrix using simple dot product between vector sets. 
* * out = x1 * x2 @@ -131,17 +262,17 @@ class GramMatrixBase { * @param ld2 leading dimension of x2 * @param ld_out leading dimension of out */ - void linear(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void linear(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { math_t alpha = 1.0; math_t beta = 0.0; @@ -182,37 +313,198 @@ class GramMatrixBase { } } - /** Calculates the Gram matrix using Euclidean distance. + protected: + bool get_is_row_major(dense_output_matrix_view_t matrix) + { + return (matrix.stride(1) == 1); + } + + bool get_is_row_major(dense_input_matrix_view_t matrix) + { + return (matrix.stride(1) == 1); + } + + bool get_is_col_major(dense_output_matrix_view_t matrix) + { + return (matrix.stride(0) == 1); + } + + bool get_is_col_major(dense_input_matrix_view_t matrix) + { + return (matrix.stride(0) == 1); + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 * * Can be used as a building block for more complex kernel functions. * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] */ - virtual void distance(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - raft::distance::distance( - raft::device_resources(stream), x1, x2, out, n1, n2, n_cols, is_row_major); + void linear(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(x1) && get_is_row_major(x2) && get_is_row_major(out); + bool is_col_major = get_is_col_major(x1) && get_is_col_major(x2) && get_is_col_major(out); + ASSERT(is_row_major || is_col_major, + "GramMatrix leading dimensions for x1, x2 and out do not match"); + + // check dimensions + int n1 = out.extent(0); + int n2 = out.extent(1); + int n_cols = x1.extent(1); + ASSERT(x1.extent(0) == n1, "GramMatrix input matrix dimensions for x1 and out do not match"); + ASSERT(x2.extent(0) == n2, "GramMatrix input matrix dimensions for x2 and out do not match"); + ASSERT(x2.extent(1) == n_cols, "GramMatrix input matrix dimensions for x1 and x2 do not match"); + + // extract major stride + int ld1 = is_row_major ? x1.stride(0) : x1.stride(1); + int ld2 = is_row_major ? x2.stride(0) : x2.stride(1); + int ld_out = is_row_major ? 
out.stride(0) : out.stride(1); + + math_t alpha = 1.0; + math_t beta = 0.0; + if (is_row_major) { + // #TODO: Use mdspan-based API when stride-capable + // https://github.com/rapidsai/raft/issues/875 + raft::linalg::gemm(handle, + true, + false, + n2, + n1, + n_cols, + &alpha, + x2.data_handle(), + ld2, + x1.data_handle(), + ld1, + &beta, + out.data_handle(), + ld_out, + handle.get_stream()); + } else { + // #TODO: Use mdspan-based API when stride-capable + // https://github.com/rapidsai/raft/issues/875 + raft::linalg::gemm(handle, + false, + true, + n1, + n2, + n_cols, + &alpha, + x1.data_handle(), + ld1, + x2.data_handle(), + ld2, + &beta, + out.data_handle(), + ld_out, + handle.get_stream()); + } + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + */ + void linear(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(x2) && get_is_row_major(out); + bool is_col_major = get_is_col_major(x2) && get_is_col_major(out); + ASSERT(is_row_major || is_col_major, + "GramMatrix leading dimensions for x2 and out do not match"); + + // check dimensions + auto x1_structure = x1.structure_view(); + ASSERT(x1_structure.get_n_rows() == out.extent(0), + "GramMatrix input matrix dimensions for x1 and out do not match"); + ASSERT(x2.extent(0) == out.extent(1), + "GramMatrix input matrix dimensions for x2 and out do not match"); + ASSERT(x2.extent(1) == x1_structure.get_n_cols(), + "GramMatrix input matrix dimensions for x1 and x2 do not match"); + + math_t alpha = 1.0; + math_t beta = 0.0; + + raft::sparse::linalg::spmm(handle, false, true, &alpha, x1, x2, &beta, out); + } + + /** Calculates the Gram matrix using simple dot product between vector sets. + * + * out = x1 * x2 + * + * Can be used as a building block for more complex kernel functions. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + */ + void linear(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out) + { + // check is_row_major consistency + bool is_row_major = get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + int minor_out = is_row_major ? 
out.extent(1) : out.extent(0); + ASSERT(ld_out == minor_out, "Sparse linear Kernel distance does not support ld_out parameter"); + + auto x1_structure = x1.structure_view(); + auto x2_structure = x2.structure_view(); + raft::sparse::distance::distances_config_t dist_config(handle); + + // switch a,b based on is_row_major + if (!is_row_major) { + dist_config.a_nrows = x2_structure.get_n_rows(); + dist_config.a_ncols = x2_structure.get_n_cols(); + dist_config.a_nnz = x2_structure.get_nnz(); + dist_config.a_indptr = const_cast(x2_structure.get_indptr().data()); + dist_config.a_indices = const_cast(x2_structure.get_indices().data()); + dist_config.a_data = const_cast(x2.get_elements().data()); + dist_config.b_nrows = x1_structure.get_n_rows(); + dist_config.b_ncols = x1_structure.get_n_cols(); + dist_config.b_nnz = x1_structure.get_nnz(); + dist_config.b_indptr = const_cast(x1_structure.get_indptr().data()); + dist_config.b_indices = const_cast(x1_structure.get_indices().data()); + dist_config.b_data = const_cast(x1.get_elements().data()); + } else { + dist_config.a_nrows = x1_structure.get_n_rows(); + dist_config.a_ncols = x1_structure.get_n_cols(); + dist_config.a_nnz = x1_structure.get_nnz(); + dist_config.a_indptr = const_cast(x1_structure.get_indptr().data()); + dist_config.a_indices = const_cast(x1_structure.get_indices().data()); + dist_config.a_data = const_cast(x1.get_elements().data()); + dist_config.b_nrows = x2_structure.get_n_rows(); + dist_config.b_ncols = x2_structure.get_n_cols(); + dist_config.b_nnz = x2_structure.get_nnz(); + dist_config.b_indptr = const_cast(x2_structure.get_indptr().data()); + dist_config.b_indices = const_cast(x2_structure.get_indices().data()); + dist_config.b_data = const_cast(x2.get_elements().data()); + } + + raft::sparse::distance::pairwiseDistance( + out.data_handle(), dist_config, raft::distance::DistanceType::InnerProduct, 0.0); } }; + }; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh index 1aa6809bcd..bb3ff1c2f5 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
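// ---------------------------------------------------------------------------
// Editor's note on the layout checks above: with the strided mdspan views,
// `get_is_row_major` and `get_is_col_major` reduce to one rule -- a view is
// row-major when its innermost stride is 1 (stride(1) == 1) and column-major
// when stride(0) == 1; a fully contiguous view can satisfy both, which is why
// `linear` only requires x1, x2 and out to agree on one of the two. The same
// rule in a standalone, hypothetical form (`strides2d` is not a raft type):
struct strides2d {
  long s0, s1;  // stride between consecutive rows / consecutive columns
};
constexpr bool is_row_major(strides2d s) { return s.s1 == 1; }
constexpr bool is_col_major(strides2d s) { return s.s0 == 1; }
// e.g. a 4x3 row-major view padded to a leading dimension of 8 has strides {8, 1}:
static_assert(is_row_major(strides2d{8, 1}) && !is_col_major(strides2d{8, 1}),
              "padded row-major layout");
// ---------------------------------------------------------------------------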
@@ -26,19 +26,35 @@ namespace raft::distance::kernels::detail { template class KernelFactory { public: - static GramMatrixBase* create(KernelParams params, cublasHandle_t cublas_handle) + static GramMatrixBase* create(KernelParams params) { GramMatrixBase* res; // KernelParams is not templated, we convert the parameters to math_t here: math_t coef0 = params.coef0; math_t gamma = params.gamma; switch (params.kernel) { - case LINEAR: res = new GramMatrixBase(cublas_handle); break; + case LINEAR: res = new GramMatrixBase(); break; + case POLYNOMIAL: res = new PolynomialKernel(params.degree, gamma, coef0); break; + case TANH: res = new TanhKernel(gamma, coef0); break; + case RBF: res = new RBFKernel(gamma); break; + default: throw raft::exception("Kernel not implemented"); + } + return res; + } + + [[deprecated]] static GramMatrixBase* create(KernelParams params, cublasHandle_t handle) + { + GramMatrixBase* res; + // KernelParams is not templated, we convert the parameters to math_t here: + math_t coef0 = params.coef0; + math_t gamma = params.gamma; + switch (params.kernel) { + case LINEAR: res = new GramMatrixBase(handle); break; case POLYNOMIAL: - res = new PolynomialKernel(params.degree, gamma, coef0, cublas_handle); + res = new PolynomialKernel(params.degree, gamma, coef0, handle); break; - case TANH: res = new TanhKernel(gamma, coef0, cublas_handle); break; - case RBF: res = new RBFKernel(gamma); break; + case TANH: res = new TanhKernel(gamma, coef0, handle); break; + case RBF: res = new RBFKernel(gamma, handle); break; default: throw raft::exception("Kernel not implemented"); } return res; diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index d1465efdb0..7ff886c677 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -17,10 +17,12 @@ #pragma once #include "gram_matrix.cuh" -#include +#include #include #include +#include +#include namespace raft::distance::kernels::detail { @@ -100,6 +102,38 @@ __global__ void tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t ga } } +/** Epiloge function for rbf kernel using expansion. + * + * Calculates output_ij = exp(-gain * (norm_x_i + norm_y_j - 2*input_ij)); + * + * Intended usage + * - input is the product of two matrices X and Y input_ij = sum_k X_ik * Y_jk + * - norm_x_i = l2_norm(x_i), where x_i is the i-th row of matrix X + * - norm_y_j = l2_norm(y_j), where y_j is the j-th row of matrix Y + * + * @param inout device vector in column major format, size [ld * cols] + * @param ld leading dimension of the inout buffer + * @param rows number of rows (rows <= ld) + * @param cols number of columns + * @param norm_x l2-norm of X's rows + * @param norm_y l2-norm of Y's rows + * @param gain + */ +template +__global__ void rbf_kernel_expanded( + math_t* inout, int ld, int rows, int cols, math_t* norm_x, math_t* norm_y, math_t gain) +{ + for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; + tidy += blockDim.y * gridDim.y) { + math_t norm_y_val = norm_y[tidy]; + for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; + tidx += blockDim.x * gridDim.x) { + inout[tidx + tidy * ld] = + exp(-1.0 * gain * (norm_x[tidx] + norm_y_val - inout[tidx + tidy * ld] * 2)); + } + } +} + /** * Create a kernel matrix using polynomial kernel function. 
*/ @@ -138,11 +172,42 @@ class PolynomialKernel : public GramMatrixBase { * @param exponent * @param gain * @param offset - * @param cublas_handle */ - PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t cublas_handle) - : GramMatrixBase(cublas_handle), exponent(exponent), gain(gain), offset(offset) + PolynomialKernel(exp_t exponent, math_t gain, math_t offset) + : GramMatrixBase(), exponent(exponent), gain(gain), offset(offset) + { + } + + [[deprecated]] PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t handle) + : GramMatrixBase(handle), exponent(exponent), gain(gain), offset(offset) + { + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); } /** Evaluate kernel matrix using polynomial kernel. @@ -150,32 +215,84 @@ class PolynomialKernel : public GramMatrixBase { * output[i,k] = (gain* + offset)^exponent, * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate kernel matrix using polynomial kernel. + * + * output[i,k] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? 
out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate the Gram matrix using the legacy interface. * * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of features in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*cols] + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] * @param [in] n2 number vectors in x2 * @param [out] out device buffer to store the Gram matrix, size [n1*n2] * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) */ - void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); GramMatrixBase::linear( x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); applyKernel(out, ld_out, n1, n2, is_row_major, stream); @@ -216,10 +333,11 @@ class TanhKernel : public GramMatrixBase { * @tparam math_t floating point type * @param gain * @param offset - * @param cublas_handle */ - TanhKernel(math_t gain, math_t offset, cublasHandle_t cublas_handle) - : GramMatrixBase(cublas_handle), gain(gain), offset(offset) + TanhKernel(math_t gain, math_t offset) : GramMatrixBase(), gain(gain), offset(offset) {} + + [[deprecated]] TanhKernel(math_t gain, math_t offset, cublasHandle_t handle) + : GramMatrixBase(handle), gain(gain), offset(offset) { } @@ -229,12 +347,87 @@ class TanhKernel : public GramMatrixBase { * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector * in the x2 set, and < , > denotes dot product. * - * @param [in] x1 device array of vectors, - * size [n1*n_cols] + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. 
+ * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate kernel matrix using tanh kernel. + * + * output_[i + k*n1] = (gain* + offset)^exponent, + * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector + * in the x2 set, and < , > denotes dot product. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 unused. + * @param norm_x2 unused. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + bool is_row_major = GramMatrixBase::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase::linear(handle, x1, x2, out); + applyKernel( + out.data_handle(), ld_out, out.extent(0), out.extent(1), is_row_major, handle.get_stream()); + } + + /** Evaluate the Gram matrix using the legacy interface. 
+ * + * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of features in x1 and x2 - * @param [in] x2 device array of vectors, - * size [n2*n_cols] + * @param [in] n_cols number of columns (features) in x1 and x2 + * @param [in] x2 device array of vectors, size [n2*n_cols] * @param [in] n2 number vectors in x2 * @param [out] out device buffer to store the Gram matrix, size [n1*n2] * @param [in] is_row_major whether the input and output matrices are in row @@ -244,18 +437,20 @@ class TanhKernel : public GramMatrixBase { * @param ld2 leading dimension of x2 (usually it is n2) * @param ld_out leading dimension of out (usually it is n1) */ - void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { + ASSERT(GramMatrixBase::legacy_interface, + "Legacy interface can only be used with legacy ctor."); GramMatrixBase::linear( x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); applyKernel(out, ld_out, n1, n2, is_row_major, stream); @@ -269,21 +464,23 @@ template class RBFKernel : public GramMatrixBase { math_t gain; - void applyKernel( - math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) + void applyKernel(math_t* inout, + int ld, + int rows, + int cols, + math_t* norm_x1, + math_t* norm_x2, + bool is_row_major, + cudaStream_t stream) { - const int n_minor = is_row_major ? cols : rows; - if (ld == n_minor) { - rbf_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( - inout, rows * cols, gain); - } else { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - rbf_kernel<<>>(inout, ld, n1, n2, gain); - } + int n1 = is_row_major ? cols : rows; + int n2 = is_row_major ? rows : cols; + math_t* norm_n1 = is_row_major ? norm_x2 : norm_x1; + math_t* norm_n2 = is_row_major ? norm_x1 : norm_x2; + rbf_kernel_expanded<<>>(inout, ld, n1, n2, norm_n1, norm_n2, gain); } public: @@ -295,65 +492,234 @@ class RBFKernel : public GramMatrixBase { * @tparam math_t floating point type * @param gain */ - RBFKernel(math_t gain) : GramMatrixBase(NULL), gain(gain) {} + RBFKernel(math_t gain) : GramMatrixBase(), gain(gain) {} + + [[deprecated]] RBFKernel(math_t gain, cublasHandle_t handle) + : GramMatrixBase(handle), gain(gain) + { + } + + void matrixRowNormL2(raft::device_resources const& handle, + dense_input_matrix_view_t matrix, + math_t* target) + { + bool is_row_major = GramMatrixBase::get_is_row_major(matrix); + int minor = is_row_major ? matrix.extent(1) : matrix.extent(0); + int ld = is_row_major ? 
matrix.stride(0) : matrix.stride(1); + ASSERT(ld == minor, "RBF Kernel lazy rowNorm compute does not support ld parameter"); + raft::linalg::rowNorm(target, + matrix.data_handle(), + matrix.extent(1), + matrix.extent(0), + raft::linalg::NormType::L2Norm, + is_row_major, + handle.get_stream()); + } + + void matrixRowNormL2(raft::device_resources const& handle, + csr_input_matrix_view_t matrix, + math_t* target) + { + auto matrix_structure = matrix.structure_view(); + raft::sparse::linalg::rowNormCsr(handle, + matrix_structure.get_indptr().data(), + matrix.get_elements().data(), + matrix_structure.get_nnz(), + matrix_structure.get_n_rows(), + target, + raft::linalg::NormType::L2Norm); + } + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, x2_k is the k-th vector + * in the x2 set, and | | denotes the Euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 dense device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void evaluate(raft::device_resources const& handle, + dense_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = handle.get_stream(); + + // lazy compute norms if not given + rmm::device_uvector<math_t> tmp_norm_x1(0, stream); + rmm::device_uvector<math_t> tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.extent(0), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.extent(0), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase<math_t>::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase<math_t>::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + handle.get_stream()); + } /** Evaluate kernel matrix using RBF kernel. * * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), * where x1_i is the i-th vector from the x1 set, x2_k is the k-th vector * in the x2 set, and | | denotes the Euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 dense device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF.
+ */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + dense_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = handle.get_stream(); + + // lazy compute norms if not given + rmm::device_uvector<math_t> tmp_norm_x1(0, stream); + rmm::device_uvector<math_t> tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.extent(0), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase<math_t>::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase<math_t>::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + handle.get_stream()); + } + + /** Evaluate kernel matrix using RBF kernel. + * + * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), + * where x1_i is the i-th vector from the x1 set, x2_k is the k-th vector + * in the x2 set, and | | denotes the Euclidean distance. + * + * @param [in] handle raft handle + * @param [in] x1 csr device matrix view, size [n1*n_cols] + * @param [in] x2 csr device matrix view, size [n2*n_cols] + * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] + * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. + * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. + */ + void evaluate(raft::device_resources const& handle, + csr_input_matrix_view_t x1, + csr_input_matrix_view_t x2, + dense_output_matrix_view_t out, + math_t* norm_x1, + math_t* norm_x2) + { + cudaStream_t stream = handle.get_stream(); + + // lazy compute norms if not given + rmm::device_uvector<math_t> tmp_norm_x1(0, stream); + rmm::device_uvector<math_t> tmp_norm_x2(0, stream); + if (norm_x1 == nullptr) { + tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); + norm_x1 = tmp_norm_x1.data(); + matrixRowNormL2(handle, x1, norm_x1); + } + if (norm_x2 == nullptr) { + tmp_norm_x2.reserve(x2.structure_view().get_n_rows(), stream); + norm_x2 = tmp_norm_x2.data(); + matrixRowNormL2(handle, x2, norm_x2); + } + + // compute L2expanded + bool is_row_major = GramMatrixBase<math_t>::get_is_row_major(out); + int ld_out = is_row_major ? out.stride(0) : out.stride(1); + GramMatrixBase<math_t>::linear(handle, x1, x2, out); + applyKernel(out.data_handle(), + ld_out, + out.extent(0), + out.extent(1), + norm_x1, + norm_x2, + is_row_major, + handle.get_stream()); + } + + /** Evaluate the Gram matrix using the legacy interface.
* * @param [in] x1 device array of vectors, size [n1*n_cols] * @param [in] n1 number of vectors in x1 - * @param [in] n_cols number of features in x1 and x2 + * @param [in] n_cols number of columns (features) in x1 and x2 * @param [in] x2 device array of vectors, size [n2*n_cols] * @param [in] n2 number of vectors in x2 * @param [out] out device buffer to store the Gram matrix, size [n1*n2] * @param [in] is_row_major whether the input and output matrices are in row * major format * @param [in] stream cuda stream - * @param ld1 leading dimension of x1, currently only ld1 == n1 is supported - * @param ld2 leading dimension of x2, currently only ld2 == n2 is supported - * @param ld_out leading dimension of out, only ld_out == n1 is supported + * @param ld1 leading dimension of x1 (usually it is n1) + * @param ld2 leading dimension of x2 (usually it is n2) + * @param ld_out leading dimension of out (usually it is n1) */ - void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) + [[deprecated]] void evaluate(const math_t* x1, + int n1, + int n_cols, + const math_t* x2, + int n2, + math_t* out, + bool is_row_major, + cudaStream_t stream, + int ld1, + int ld2, + int ld_out) { + ASSERT(GramMatrixBase<math_t>::legacy_interface, + "Legacy interface can only be used with legacy ctor."); int minor1 = is_row_major ? n_cols : n1; int minor2 = is_row_major ? n_cols : n2; int minor_out = is_row_major ? n2 : n1; ASSERT(ld1 == minor1, "RBF Kernel distance does not support ld1 parameter"); ASSERT(ld2 == minor2, "RBF Kernel distance does not support ld2 parameter"); ASSERT(ld_out == minor_out, "RBF Kernel distance does not support ld_out parameter"); - distance(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - } - /** Customize distance function withe RBF epilogue */ - void distance(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { math_t gain = this->gain; using index_t = int64_t; - auto fin_op = [gain] __device__(math_t d_val, index_t idx) { return exp(-gain * d_val); }; + rbf_fin_op fin_op{gain}; raft::distance::distance diff --git a/cpp/include/raft/distance/detail/kernels/rbf_fin_op.cuh b/cpp/include/raft/distance/detail/kernels/rbf_fin_op.cuh new file mode 100644 --- /dev/null +++ b/cpp/include/raft/distance/detail/kernels/rbf_fin_op.cuh +#pragma once + +#include // raft::exp +#include // HD + +namespace raft::distance::kernels::detail { + +/** @brief: Final op for Gram matrix with RBF kernel. + * + * Calculates output = e^(-gain * in) + * + */ +template <typename OutT> +struct rbf_fin_op { + OutT gain; + + explicit HD rbf_fin_op(OutT gain_) noexcept : gain(gain_) {} + + template <typename... Args> + HDI OutT operator()(OutT d_val, Args...
unused_args) + { + return raft::exp(-gain * d_val); + } +}; // struct rbf_fin_op + +} // namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh index 55da634145..5a33c9ce4a 100644 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ b/cpp/include/raft/distance/detail/masked_distance_base.cuh @@ -217,7 +217,7 @@ struct MaskedDistances : public BaseClass { } // tile_idx_n } // idx_g rowEpilog_op(tile_idx_m); - } // tile_idx_m + } // tile_idx_m } private: diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index c6b09be31e..58b5daa8ca 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -18,7 +18,7 @@ #include // ceildiv #include // RAFT_CUDA_TRY -#include // size_t +#include // size_t namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh new file mode 100644 index 0000000000..dd58ab4328 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // raft::identity_op +#include // ops::* +#include // ops::has_cutlass_op +#include // rbf_fin_op +#include // pairwise_matrix_params +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::distance::detail { + +template <typename OpT, typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT> +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) RAFT_EXPLICIT; + +}; // namespace raft::distance::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ OpT, DataT, AccT, OutT, FinOpT, IdxT) \ extern template void raft::distance::detail:: \ pairwise_matrix_dispatch<OpT<DataT, AccT, IdxT>, DataT, AccT, OutT, FinOpT, IdxT>( \ OpT<DataT, AccT, IdxT> distance_op, \ IdxT m, \ IdxT n, \ IdxT k, \ const DataT* x, \ const DataT* y, \ const DataT* x_norm, \ const DataT* y_norm, \ OutT* out, \ FinOpT fin_op, \ cudaStream_t stream, \ bool is_row_major) + +/* + * Hierarchy of instantiations: + * + * This file defines extern template instantiations of the distance kernels. The + * instantiation of the public API is handled in raft/distance/distance-ext.cuh. + * + * After adding an instance here, make sure to also add the instance there. + */ + +// The following two instances are used in the RBF kernel object. Note the use of int64_t for the +// index type.
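An aside before those two instances: the Gram matrix path first produces the squared distance d = |x1_i - x2_k|^2 (via the unexpanded L2 distance, or via the expanded form |x1_i|^2 + |x2_k|^2 - 2<x1_i, x2_k> in applyKernel), and rbf_fin_op then maps d to exp(-gain * d). A standalone host-side sketch of that mapping (not part of this diff; the struct name and values are illustrative):

#include <cmath>
#include <cstdio>

// Host-only stand-in for the device-side rbf_fin_op: maps a squared L2
// distance d_val to the RBF Gram value exp(-gain * d_val).
struct rbf_fin_op_host {
  float gain;
  float operator()(float d_val) const { return std::exp(-gain * d_val); }
};

int main()
{
  rbf_fin_op_host op{0.5f};
  // |x1_i - x2_k|^2 == 4.0  ->  exp(-0.5 * 4.0) == exp(-2.0) ~= 0.135335
  std::printf("%f\n", op(4.0f));
  return 0;
}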
+instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +// Rest of instances +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); 
+instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh new file mode 100644 index 0000000000..bb4422735b --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -0,0 +1,130 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/detail/pairwise_matrix/. Previously, + * raft::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include // ops::has_cutlass_op +#include // dispatch_sm60 +#include // pairwise_matrix_params +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and src/distance/detail/pairwise_matrix/dispatch_*.cu. 
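To make the include contract in the NOTE above concrete: a translation unit that wants the CUTLASS path must include the SM80 dispatch header before dispatch-inl.cuh, because dispatch-inl.cuh only forward-declares pairwise_matrix_sm80_dispatch. A minimal sketch of such an instantiating .cu file (the file name, the chosen distance op, and its header are illustrative assumptions, not part of this diff):

// hypothetical file: src/distance/detail/pairwise_matrix/dispatch_l2_expanded.cu
#include <raft/core/operators.hpp>                             // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh>        // ops::l2_exp_distance_op
// Include order matters: the CUTLASS-backed SM80 dispatch comes first, so the
// forward declaration in dispatch-inl.cuh is satisfied at instantiation time.
#include <raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh>
#include <raft/distance/detail/pairwise_matrix/dispatch-inl.cuh>

// Explicit instantiation matching the extern template declared in dispatch-ext.cuh.
template void raft::distance::detail::pairwise_matrix_dispatch<
  raft::distance::detail::ops::l2_exp_distance_op<float, float, int>,
  float, float, float, raft::identity_op, int>(
  raft::distance::detail::ops::l2_exp_distance_op<float, float, int>, int, int, int,
  const float*, const float*, const float*, const float*, float*, raft::identity_op,
  cudaStream_t, bool);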
+ +namespace raft::distance::detail { + +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. +template <typename OpT, typename IdxT, typename DataT, typename OutT, typename FinOpT, typename SM_compat_t> +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params<IdxT, DataT, OutT, FinOpT>, + SM_compat_t, + cudaStream_t); + +template <typename OpT, typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT> +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params<IdxT, DataT, OutT, FinOpT> params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + + // On CUDA 12: + // - always execute normal kernel + // + // On CUDA 11 and below: + // - execute CUTLASS-based kernel on SM_80 and above + // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; + + constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; + constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op<OpT>(); + + if constexpr (is_ctk_12 || cutlass_op_unavailable) { + // Always execute legacy kernels on CUDA 12 + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); + } else { + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); + + // Get pointer to SM60 kernel to determine the runtime architecture of the + // current system. Other methods to determine the architecture (that do not + // require a pointer) can be error prone. See: + // https://github.com/NVIDIA/cub/issues/545 + auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); + void* kernel_ptr = reinterpret_cast<void*>(sm60_wrapper.kernel_ptr); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); + + if (cutlass_range.contains(runtime_arch)) { + // If device is SM_80 or later, use CUTLASS-based kernel. + pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); + } else { + // Reuse kernel wrapper that we obtained above. This avoids performing the + // dispatch twice. + sm60_wrapper.launch(distance_op, params, stream); + } + } +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index e04b56ee8a..4a52b7ebe7 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,123 +15,10 @@ */ #pragma once -/* This file has two responsibilities: - * - * 1. Dispatch to the correct implementation of a kernel based on the - * architecture of the device on which the kernel will be launched. For - * instance, the cosine distance has a CUTLASS-based implementation that can - * be used on SM80+ and the normal implementation that is used on older - * architectures. - * - * 2. Provide concise function templates that can be instantiated in - * src/distance/distance/specializations/detail/.
Previously, - * raft::distance::detail::distance was instantiated. The function - * necessarily required a large set of include files, which slowed down the - * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions - * do not require as large an include files set, which speeds up the build. - */ - -#include // ops::has_cutlass_op -#include // dispatch_sm60 -#include // pairwise_matrix_params -#include // raft::util::arch::SM_* - -// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. -// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). -// Therefore, it is the including file's responsibility to include the correct -// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh -// and the specializations in src/distance/distance/specializations/detail/. - -namespace raft::distance::detail { - -// This forward-declaration ensures that we do not need to include -// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling -// all the non-CUTLASS based distance specializations faster. For CUTLASS-based -// distances, dispatch_sm80.cuh has to be included by the file including this -// file. -template -void pairwise_matrix_sm80_dispatch(OpT, - pairwise_matrix_params, - SM_compat_t, - cudaStream_t); - -template -void pairwise_matrix_instantiation_point(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) -{ - // On CUDA 12: - // - always execute normal kernel - // - // On CUDA 11 and below: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel below SM_80 - namespace arch = raft::util::arch; - - constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; - constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - - if constexpr (is_ctk_12 || cutlass_op_unavailable) { - // Always execute legacy kernels on CUDA 12 - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else { - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - - // Get pointer to SM60 kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); - void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); - } else { - // Reuse kernel wrapper that we obtained above. This avoids performing the - // dispatch twice. - sm60_wrapper.launch(distance_op, params, stream); - } - } -} - -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? 
n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } - pairwise_matrix_instantiation_point(distance_op, params, stream); -} +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "dispatch-inl.cuh" +#endif -}; // namespace raft::distance::detail +#ifdef RAFT_COMPILED +#include "dispatch-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h index 67c01448dc..ebe6d0c80a 100644 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,8 +57,8 @@ namespace threadblock { /// /// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator /// -template diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh new file mode 100644 index 0000000000..3f7f2b0a23 --- /dev/null +++ b/cpp/include/raft/distance/distance-ext.cuh @@ -0,0 +1,1065 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include // raft::device_matrix_view +#include // raft::identity_op +#include // raft::resources +#include // rbf_fin_op +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT +#include // rmm::device_uvector + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template <raft::distance::DistanceType DistT, typename DataT, typename AccT, typename OutT, typename FinalLambda, typename IdxT = int> +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template <raft::distance::DistanceType DistT, typename DataT, typename AccT, typename OutT, typename IdxT = int> +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template <raft::distance::DistanceType DistT, typename DataT, typename AccT, typename OutT, typename IdxT = int> +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) RAFT_EXPLICIT; + +template <raft::distance::DistanceType DistT, typename DataT, typename AccT, typename OutT, typename IdxT, typename layout> +size_t getWorkspaceSize(raft::device_matrix_view<DataT, IdxT, layout> const& x, + raft::device_matrix_view<DataT, IdxT, layout> const& y) RAFT_EXPLICIT; + +template <raft::distance::DistanceType DistT, typename DataT, typename AccT, typename OutT, typename IdxT = int> +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template <typename Type, typename IdxT = int> +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector<Type>& workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template <typename Type, typename IdxT = int> +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template <raft::distance::DistanceType DistT, typename DataT, typename AccT, typename OutT, typename layout = raft::layout_c_contiguous, typename IdxT = int> +void distance(raft::resources const& handle, + raft::device_matrix_view<DataT, IdxT, layout> const x, + raft::device_matrix_view<DataT, IdxT, layout> const y, + raft::device_matrix_view<OutT, IdxT, layout> dist, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template <typename Type, typename layout = layout_c_contiguous, typename IdxT = int> +void pairwise_distance(raft::resources const& handle, + device_matrix_view<Type, IdxT, layout> const x, + device_matrix_view<Type, IdxT, layout> const y, + device_matrix_view<Type, IdxT, layout> dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +}; // namespace distance +}; // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +/* + * Hierarchy of instantiations: + * + * This file defines the extern template instantiations for the public API of + * raft::distance. To improve compile times, the extern template instantiation + * of the distance kernels is handled in + * distance/detail/pairwise_matrix/dispatch-ext.cuh. + * + * After adding an instance here, make sure to also add the instance to + * dispatch-ext.cuh and the corresponding .cu files. + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void raft::distance::distance<DT, DataT, AccT, OutT, FinalLambda, IdxT>( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type.
+instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + 
raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); 
+instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); 
+instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, 
double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + extern template size_t raft::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. 
+instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); 
+instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + 
raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance<DataT, IdxT>(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector<DataT>& workspace, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance<DataT, IdxT>(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + extern template void raft::distance::distance<DistT, DataT, AccT, OutT, layout, IdxT>( \ + raft::resources const& handle, \ + raft::device_matrix_view<DataT, IdxT, layout> const x, \ + raft::device_matrix_view<DataT, IdxT, layout> const y, \ + raft::device_matrix_view<OutT, IdxT, layout> dist, \ + DataT metric_arg) + +// Again, we might want to consider reining in the number of instantiations...
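An aside before those instances: the mdspan overload instantiated by this macro is typically called like the following standalone sketch (the metric and the sizes are made up; this is a hedged usage illustration, not part of the diff):

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance.cuh>

int main()
{
  raft::device_resources handle;
  int m = 4, n = 3, k = 8;
  auto x    = raft::make_device_matrix<float, int>(handle, m, k);
  auto y    = raft::make_device_matrix<float, int>(handle, n, k);
  auto dist = raft::make_device_matrix<float, int>(handle, m, n);
  // Matches the extern template instance (L2Expanded, float, float, float,
  // layout_c_contiguous, int) declared by the macro above.
  raft::distance::distance<raft::distance::DistanceType::L2Expanded,
                           float, float, float, raft::layout_c_contiguous, int>(
    handle, x.view(), y.view(), dist.view());
  handle.sync_stream();
  return 0;
}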
+instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + 
raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); 
+instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + extern template void raft::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + raft::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh new file mode 100644 index 0000000000..3399443765 --- /dev/null +++ b/cpp/include/raft/distance/distance-inl.cuh @@ -0,0 +1,477 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace raft { +namespace distance { + +/** + * @defgroup pairwise_distance pointer-based pairwise distance prims + * @{ + */ + +/** + * @brief Evaluate pairwise distances with the user epilogue lamba allowed + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam FinalLambda user-defined epilogue lamba + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param fin_op the final gemm epilogue lambda + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + * + * @note fin_op: This is a device lambda which is supposed to operate upon the + * input which is AccT and returns the output in OutT. It's signature is + * as follows:
<pre>OutT fin_op(AccT in, int g_idx);</pre>
. If one needs + * any other parameters, feel free to pass them via closure. + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + detail::distance( + handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points + * @param y second set of points + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) +{ + return detail::getWorkspaceSize(x, y, m, n, k); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param x first set of points (size m*k) + * @param y second set of points (size n*k) + * @return number of bytes needed in workspace + * + * @note If the specified DistT doesn't need the workspace at all, it + * returns 0. 
+ */ +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + + return getWorkspaceSize( + x.data_handle(), y.data_handle(), x.extent(0), y.extent(0), x.extent(1)); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + auto dispatch = [&](auto distance_type) { + auto worksize = getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + detail::distance( + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); + }; + + switch (metric) { + case DistanceType::Canberra: + dispatch(std::integral_constant{}); + break; + case DistanceType::CorrelationExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::CosineExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HammingUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::HellingerExpanded: + dispatch(std::integral_constant{}); + break; + case raft::distance::DistanceType::InnerProduct: + dispatch(std::integral_constant{}); + break; + case DistanceType::JensenShannon: + dispatch(std::integral_constant{}); + break; + case DistanceType::KLDivergence: + dispatch(std::integral_constant{}); + break; + case DistanceType::L1: + 
dispatch(std::integral_constant{}); + break; + case DistanceType::L2Expanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtExpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2SqrtUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::L2Unexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::Linf: + dispatch(std::integral_constant{}); + break; + case DistanceType::LpUnexpanded: + dispatch(std::integral_constant{}); + break; + case DistanceType::RusselRaoExpanded: + dispatch(std::integral_constant{}); + break; + default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param metric distance metric + * @param isRowMajor whether the matrices are row-major or col-major + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) +{ + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + pairwise_distance( + handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); +} + +/** @} */ + +/** + * \defgroup distance_mdspan Pairwise distance functions + * @{ + */ + +/** + * @brief Evaluate pairwise distances for the simple use case. + * + * Note: Only contiguous row- or column-major layouts supported currently. 
+ * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * #include + * + * raft::raft::device_resources handle; + * int n_samples = 5000; + * int n_features = 50; + * + * auto input = raft::make_device_matrix(handle, n_samples, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * auto output = raft::make_device_matrix(handle, n_samples, n_samples); + * + * raft::random::make_blobs(handle, input.view(), labels.view()); + * auto metric = raft::distance::DistanceType::L2SqrtExpanded; + * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); + * @endcode + * + * @tparam DistanceType which distance to evaluate + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type + * @param handle raft handle for managing expensive resources + * @param x first set of points (size n*k) + * @param y second set of points (size m*k) + * @param dist output distance matrix (size n*m) + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + + constexpr auto is_rowmajor = std::is_same_v; + + distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + is_rowmajor, + metric_arg); +} + +/** + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam IdxT indexing type + * @param handle raft handle for managing expensive resources + * @param x first matrix of points (size mxk) + * @param y second matrix of points (size nxk) + * @param dist output distance matrix (size mxn) + * @param metric distance metric + * @param metric_arg metric argument (used for Minkowski distance) + */ +template +void pairwise_distance(raft::resources const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) +{ + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); + RAFT_EXPECTS(dist.extent(0) == x.extent(0), + "Number of rows in output must be equal to " + "number of rows in X"); + RAFT_EXPECTS(dist.extent(1) == y.extent(0), + "Number of columns in output must be equal to " + "number of rows in Y"); + + RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); + RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); + RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); + + constexpr auto rowmajor = std::is_same_v; + + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); + + pairwise_distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + metric, + rowmajor, + 
metric_arg); +} + +/** @} */ + +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 5216902635..de70cd4691 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -13,470 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __DISTANCE_H -#define __DISTANCE_H - #pragma once -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace distance { - -/** - * @defgroup pairwise_distance pointer-based pairwise distance prims - * @{ - */ - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:
<pre>OutType fin_op(AccType in, int g_idx);</pre>
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) -{ - return detail::getWorkspaceSize(x, y, m, n, k); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points (size m*k) - * @param y second set of points (size n*k) - * @return number of bytes needed in workspace - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. 
- */ -template -size_t getWorkspaceSize(const raft::device_matrix_view x, - const raft::device_matrix_view y) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - - return getWorkspaceSize( - x.data(), y.data(), x.extent(0), y.extent(0), x.extent(1)); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - Index_ m, - Index_ n, - Index_ k, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto dispatch = [&](auto distance_type) { - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); - }; - - switch (metric) { - case DistanceType::Canberra: - dispatch(std::integral_constant{}); - break; - case DistanceType::CorrelationExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::CosineExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HammingUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HellingerExpanded: - dispatch(std::integral_constant{}); - break; - case raft::distance::DistanceType::InnerProduct: - dispatch(std::integral_constant{}); - break; - case DistanceType::JensenShannon: - dispatch(std::integral_constant{}); - break; - case DistanceType::KLDivergence: - dispatch(std::integral_constant{}); - break; - case DistanceType::L1: 
- dispatch(std::integral_constant{}); - break; - case DistanceType::L2Expanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Unexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::Linf: - dispatch(std::integral_constant{}); - break; - case DistanceType::LpUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::RusselRaoExpanded: - dispatch(std::integral_constant{}); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - Index_ m, - Index_ n, - Index_ k, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - pairwise_distance( - handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); -} - -/** @} */ - -/** - * \defgroup distance_mdspan Pairwise distance functions - * @{ - */ - -/** - * @brief Evaluate pairwise distances for the simple use case. - * - * Note: Only contiguous row- or column-major layouts supported currently. 
- * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * #include - * - * raft::raft::device_resources handle; - * int n_samples = 5000; - * int n_features = 50; - * - * auto input = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto output = raft::make_device_matrix(handle, n_samples, n_samples); - * - * raft::random::make_blobs(handle, input.view(), labels.view()); - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); - * @endcode - * - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points (size n*k) - * @param y second set of points (size m*k) - * @param dist output distance matrix (size n*m) - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - InType metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - - constexpr auto is_rowmajor = std::is_same_v; - - distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - is_rowmajor, - metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param handle raft handle for managing expensive resources - * @param x first matrix of points (size mxk) - * @param y second matrix of points (size nxk) - * @param dist output distance matrix (size mxn) - * @param metric distance metric - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, - raft::distance::DistanceType metric, - Type metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); - - constexpr auto rowmajor = std::is_same_v; - - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - - pairwise_distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - metric, - 
rowmajor, - metric_arg); -} - -/** @} */ - -}; // namespace distance -}; // namespace raft +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "distance-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "distance-ext.cuh" #endif diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index 4060147f1d..d17ef358ee 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -74,8 +74,6 @@ inline bool is_min_close(DistanceType metric) bool select_min; switch (metric) { case DistanceType::InnerProduct: - case DistanceType::CosineExpanded: - case DistanceType::CorrelationExpanded: // Similarity metrics have the opposite meaning, i.e. nearest neighbors are those with larger // similarity (See the same logic at cpp/include/raft/sparse/spatial/detail/knn.cuh:362 // {perform_k_selection}) diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh new file mode 100644 index 0000000000..05732c1f3f --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t +#include // raft::device_resources +#include // raft::KeyValuePair +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedL2NNMinReduce(OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh 
b/cpp/include/raft/distance/fused_l2_nn-inl.cuh new file mode 100644 index 0000000000..698d287f87 --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -0,0 +1,206 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __FUSED_L2_NN_H +#define __FUSED_L2_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** + * \ingroup fused_l2_nn + * @{ + */ +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] pairRedOp reduction operation on key value pairs + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NN(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + // When k is smaller than 32, the Policy4x4 results in redundant calculations + // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead + // that uses tiles with a smaller value of k. 
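+  //
+  // Beyond the tile policy, the branches below also pick the widest
+  // vectorized load (16-byte, 8-byte, or element-wise) that sizeof(DataT),
+  // the row size in bytes, and the alignment of the x and y pointers all
+  // allow, falling back to the next narrower width when any check fails.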
+ bool is_skinny = k < 32; + + size_t bytes = sizeof(DataT) * k; + auto px = reinterpret_cast(x); + auto py = reinterpret_cast(y); + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } else { + if (is_skinny) { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } else { + detail::fusedL2NNImpl::Policy, + ReduceOpT>( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); + } + } +} + +/** + * @brief Wrapper around fusedL2NN with minimum reduction operators. + * + * fusedL2NN cannot be compiled in the distance library due to the lambda + * operators, so this wrapper covers the most common case (minimum). + * This should be preferred to the more generic API when possible, in order to + * reduce compilation times for users of the shared library. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances (e.g. raft::KeyValuePair) or store only the min + * distances. + * @tparam IdxT indexing arithmetic type + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) +{ + MinAndDistanceReduceOp redOp; + KVPMinReduce pairRedOp; + + fusedL2NN( + min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index e832bcb020..b1a3551323 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -13,218 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ - -#ifndef __FUSED_L2_NN_H -#define __FUSED_L2_NN_H - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -/** - * \defgroup fused_l2_nn Fused 1-nearest neighbors - * @{ - */ - -template -using KVPMinReduce = detail::KVPMinReduceImpl; - -template -using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; - -template -using MinReduceOp = detail::MinReduceOpImpl; - -/** @} */ - -/** - * Initialize array using init value from reduction op - */ -template -void initialize( - raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - detail::initialize(min, m, maxVal, redOp, handle.get_stream()); -} - -/** - * \ingroup fused_l2_nn - * @{ - */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. 
- bool is_skinny = k < 32; - - size_t bytes = sizeof(DataT) * k; - auto px = reinterpret_cast(x); - auto py = reinterpret_cast(y); - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } -} - -/** - * @brief Wrapper around fusedL2NN with minimum reduction operators. - * - * fusedL2NN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * This should be preferred to the more generic API when possible, in order to - * reduce compilation times for users of the shared library. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; - - fusedL2NN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); -} - -/** @} */ - -} // namespace distance -} // namespace raft +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "fused_l2_nn-inl.cuh" +#endif +#ifdef RAFT_COMPILED +#include "fused_l2_nn-ext.cuh" #endif diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh new file mode 100644 index 0000000000..1bcd7d8dba --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance { + +/** + * \defgroup fused_l2_nn Fused 1-nearest neighbors + * @{ + */ + +template +using KVPMinReduce = detail::KVPMinReduceImpl; + +template +using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; + +template +using MinReduceOp = detail::MinReduceOpImpl; + +/** @} */ + +/** + * Initialize array using init value from reduction op + */ +template +void initialize( + raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + detail::initialize(min, m, maxVal, redOp, handle.get_stream()); +} + +} // namespace raft::distance diff --git a/cpp/include/raft/distance/specializations.cuh b/cpp/include/raft/distance/specializations.cuh index 5944534be7..ed0b6848ae 100644 --- a/cpp/include/raft/distance/specializations.cuh +++ b/cpp/include/raft/distance/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef __DISTANCE_SPECIALIZATIONS_H -#define __DISTANCE_SPECIALIZATIONS_H - #pragma once -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py deleted file mode 100644 index 63ae6580b4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/00_write_template.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 - -# This template manages all files in this directory, apart from -# inner_product.cuh and kernels.cuh. - - -# NOTE: this template is not perfectly formatted. Use pre-commit to get -# everything in shape again. -start_template = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py
deleted file mode 100644
index 63ae6580b4..0000000000
--- a/cpp/include/raft/distance/specializations/detail/00_write_template.py
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env python3
-
-# This template manages all files in this directory, apart from
-# inner_product.cuh and kernels.cuh.
-
-
-# NOTE: this template is not perfectly formatted. Use pre-commit to get
-# everything in shape again.
-start_template = """/*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <raft/distance/detail/pairwise_matrix/dispatch.cuh>
-
-namespace raft::distance::detail {
-
-"""
-
-extern_template = """
-extern template void pairwise_matrix_instantiation_point<OpT, IdxT, DataT, OutT, FinopT>(
-  OpT,
-  pairwise_matrix_params<IdxT, DataT, OutT, FinopT>,
-  cudaStream_t);
-"""
-
-end_template = """}  // namespace raft::distance::detail
-"""
-
-data_type_instances = [
-    dict(
-        DataT="float",
-        AccT="float",
-        OutT="float",
-        IdxT="int",
-    ),
-    dict(
-        DataT="double",
-        AccT="double",
-        OutT="double",
-        IdxT="int",
-    ),
-]
-
-
-
-
-op_instances = [
-    dict(
-        path_prefix="canberra",
-        OpT="ops::canberra_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="correlation",
-        OpT="ops::correlation_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="cosine",
-        OpT="ops::cosine_distance_op<DataT, AccT, IdxT>",
-        # cosine uses CUTLASS for SM80+
-    ),
-    dict(
-        path_prefix="hamming_unexpanded",
-        OpT="ops::hamming_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="hellinger_expanded",
-        OpT="ops::hellinger_distance_op<DataT, AccT, IdxT>",
-    ),
-    # inner product is handled by cublas.
-    dict(
-        path_prefix="jensen_shannon",
-        OpT="ops::jensen_shannon_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="kl_divergence",
-        OpT="ops::kl_divergence_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="l1",
-        OpT="ops::l1_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="l2_expanded",
-        OpT="ops::l2_exp_distance_op<DataT, AccT, IdxT>",
-        # L2 expanded uses CUTLASS for SM80+
-    ),
-    dict(
-        path_prefix="l2_unexpanded",
-        OpT="ops::l2_unexp_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="l_inf",
-        OpT="ops::l_inf_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="lp_unexpanded",
-        OpT="ops::lp_unexp_distance_op<DataT, AccT, IdxT>",
-    ),
-    dict(
-        path_prefix="russel_rao",
-        OpT="ops::russel_rao_distance_op<DataT, AccT, IdxT>",
-    ),
-]
-
-def fill_in(s, template):
-    for k, v in template.items():
-        s = s.replace(k, v)
-    return s
-
-for op_instance in op_instances:
-    path = fill_in("path_prefix.cuh", op_instance)
-    with open(path, "w") as f:
-        f.write(start_template)
-
-        for data_type_instance in data_type_instances:
-            op_data_instance = {
-                k: fill_in(v, data_type_instance)
-                for k, v in op_instance.items()
-            }
-            instance = {
-                **op_data_instance,
-                **data_type_instance,
-                "FinopT": "raft::identity_op",
-            }
-
-            text = fill_in(extern_template, instance)
-
-            f.write(text)
-
-        f.write(end_template)
diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh
deleted file mode 100644
index 276c85e5f6..0000000000
--- a/cpp/include/raft/distance/specializations/detail/canberra.cuh
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - float, - float, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - double, - double, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh deleted file mode 100644 index f019f678df..0000000000 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - float, - float, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - double, - double, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh deleted file mode 100644 index dcde4ec286..0000000000 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::cosine_distance_op, - int, - double, - double, - raft::identity_op>(ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh deleted file mode 100644 index 1d6964fbce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - float, - float, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - double, - double, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh deleted file mode 100644 index f96a06f919..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - float, - float, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - double, - double, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/inner_product.cuh b/cpp/include/raft/distance/specializations/detail/inner_product.cuh deleted file mode 100644 index d97d678928..0000000000 --- a/cpp/include/raft/distance/specializations/detail/inner_product.cuh +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh deleted file mode 100644 index 0b58646582..0000000000 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - float, - float, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - double, - double, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kernels.cuh b/cpp/include/raft/distance/specializations/detail/kernels.cuh deleted file mode 100644 index 75c9c023e8..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kernels.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -extern template class raft::distance::kernels::detail::GramMatrixBase; -extern template class raft::distance::kernels::detail::GramMatrixBase; - -extern template class raft::distance::kernels::detail::PolynomialKernel; -extern template class raft::distance::kernels::detail::PolynomialKernel; - -extern template class raft::distance::kernels::detail::TanhKernel; -extern template class raft::distance::kernels::detail::TanhKernel; - -// These are somehow missing a kernel definition which is causing a compile error -// extern template class raft::distance::kernels::detail::RBFKernel; -// extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh deleted file mode 100644 index 5c164e0fd4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh deleted file mode 100644 index 870627d909..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh deleted file mode 100644 index ee3207bcce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_exp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh deleted file mode 100644 index 1fbf57632b..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh deleted file mode 100644 index 388d3bf439..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l_inf_distance_op, - int, - double, - double, - raft::identity_op>(ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh deleted file mode 100644 index d8e86ce6f2..0000000000 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh deleted file mode 100644 index 4803fb8ab0..0000000000 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - float, - float, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - double, - double, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index a34f696e9e..ed0b6848ae 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -13,22 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh index 88e1216635..9588a7f329 100644 --- a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh +++ b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,115 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. 
 */
-
 #pragma once
-
-#include <raft/core/kvp.hpp>
-#include <raft/distance/fused_l2_nn.cuh>
-
-namespace raft {
-namespace distance {
-
-extern template void fusedL2NNMinReduce<float, raft::KeyValuePair<int, float>, int>(
-  raft::KeyValuePair<int, float>* min,
-  const float* x,
-  const float* y,
-  const float* xn,
-  const float* yn,
-  int m,
-  int n,
-  int k,
-  void* workspace,
-  bool sqrt,
-  bool initOutBuffer,
-  cudaStream_t stream);
-extern template void fusedL2NNMinReduce<float, raft::KeyValuePair<int64_t, float>, int64_t>(
-  raft::KeyValuePair<int64_t, float>* min,
-  const float* x,
-  const float* y,
-  const float* xn,
-  const float* yn,
-  int64_t m,
-  int64_t n,
-  int64_t k,
-  void* workspace,
-  bool sqrt,
-  bool initOutBuffer,
-  cudaStream_t stream);
-extern template void fusedL2NNMinReduce<double, raft::KeyValuePair<int, double>, int>(
-  raft::KeyValuePair<int, double>* min,
-  const double* x,
-  const double* y,
-  const double* xn,
-  const double* yn,
-  int m,
-  int n,
-  int k,
-  void* workspace,
-  bool sqrt,
-  bool initOutBuffer,
-  cudaStream_t stream);
-extern template void fusedL2NNMinReduce<double, raft::KeyValuePair<int64_t, double>, int64_t>(
-  raft::KeyValuePair<int64_t, double>* min,
-  const double* x,
-  const double* y,
-  const double* xn,
-  const double* yn,
-  int64_t m,
-  int64_t n,
-  int64_t k,
-  void* workspace,
-  bool sqrt,
-  bool initOutBuffer,
-  cudaStream_t stream);
-extern template void fusedL2NNMinReduce<float, float, int>(float* min,
-                                                           const float* x,
-                                                           const float* y,
-                                                           const float* xn,
-                                                           const float* yn,
-                                                           int m,
-                                                           int n,
-                                                           int k,
-                                                           void* workspace,
-                                                           bool sqrt,
-                                                           bool initOutBuffer,
-                                                           cudaStream_t stream);
-extern template void fusedL2NNMinReduce<float, float, int64_t>(float* min,
-                                                               const float* x,
-                                                               const float* y,
-                                                               const float* xn,
-                                                               const float* yn,
-                                                               int64_t m,
-                                                               int64_t n,
-                                                               int64_t k,
-                                                               void* workspace,
-                                                               bool sqrt,
-                                                               bool initOutBuffer,
-                                                               cudaStream_t stream);
-extern template void fusedL2NNMinReduce<double, double, int>(double* min,
-                                                             const double* x,
-                                                             const double* y,
-                                                             const double* xn,
-                                                             const double* yn,
-                                                             int m,
-                                                             int n,
-                                                             int k,
-                                                             void* workspace,
-                                                             bool sqrt,
-                                                             bool initOutBuffer,
-                                                             cudaStream_t stream);
-extern template void fusedL2NNMinReduce<double, double, int64_t>(double* min,
-                                                                 const double* x,
-                                                                 const double* y,
-                                                                 const double* xn,
-                                                                 const double* yn,
-                                                                 int64_t m,
-                                                                 int64_t n,
-                                                                 int64_t k,
-                                                                 void* workspace,
-                                                                 bool sqrt,
-                                                                 bool initOutBuffer,
-                                                                 cudaStream_t stream);
-
-}  // namespace distance
-}  // namespace raft
\ No newline at end of file
+#pragma message(                                                  \
+  __FILE__                                                        \
+  " is deprecated and will be removed."                           \
+  " Including specializations is not necessary any more."         \
+  " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html")
diff --git a/cpp/include/raft/lap/lap.cuh b/cpp/include/raft/lap/lap.cuh
index ca7d5e96a9..f7828294cd 100644
--- a/cpp/include/raft/lap/lap.cuh
+++ b/cpp/include/raft/lap/lap.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -24,9 +24,9 @@
 
 #pragma once
 
-#pragma message(__FILE__ \
-                " is deprecated and will be removed in a future release." \
-                " Please use the raft/solver version instead.")
+#pragma message(__FILE__                                                  \
+                " is deprecated and will be removed in a future release." \
+                " Please use the raft/solver version instead.")
 
 #include <raft/solver/linear_assignment.cuh>
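The deprecation shims above pair with the dispatch-header layout introduced at the top of this patch. As a compact restatement (the foo* file names are hypothetical; the two macros are the real ones used throughout this change), a public RAFT header now reads:

// foo-inl.cuh carries the full template definitions; foo-ext.cuh carries only
// extern template declarations whose definitions are compiled into libraft.
// The public foo.cuh selects one or both, so no separate specializations
// header needs to be included any more:

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "foo-inl.cuh"  // header-only mode: each TU may instantiate templates itself
#endif
#ifdef RAFT_COMPILED
#include "foo-ext.cuh"  // linked against libraft: skip local instantiation of common types
#endif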
diff --git a/cpp/include/raft/lap/lap.hpp b/cpp/include/raft/lap/lap.hpp
index 30f2b53e52..5472422053 100644
--- a/cpp/include/raft/lap/lap.hpp
+++ b/cpp/include/raft/lap/lap.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2020-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -24,8 +24,8 @@
 
 #pragma once
 
-#pragma message(__FILE__ \
-                " is deprecated and will be removed in a future release." \
-                " Please use the cuh version instead.")
+#pragma message(__FILE__                                                  \
+                " is deprecated and will be removed in a future release." \
+                " Please use the cuh version instead.")
 
 #include <raft/lap/lap.cuh>
diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh
index 608c63e1a9..c19f491319 100644
--- a/cpp/include/raft/linalg/add.cuh
+++ b/cpp/include/raft/linalg/add.cuh
@@ -216,7 +216,7 @@ void add_scalar(raft::device_resources const& handle,
 
 /** @} */  // end of group add
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/binary_op.cuh b/cpp/include/raft/linalg/binary_op.cuh
index ed083a1590..88c49d1f42 100644
--- a/cpp/include/raft/linalg/binary_op.cuh
+++ b/cpp/include/raft/linalg/binary_op.cuh
@@ -82,7 +82,7 @@ void binary_op(raft::device_resources const& handle, InType in1, InType in2, Out
 
 /** @} */  // end of group binary_op
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh
index 674be207d8..48c121c359 100644
--- a/cpp/include/raft/linalg/coalesced_reduction.cuh
+++ b/cpp/include/raft/linalg/coalesced_reduction.cuh
@@ -159,7 +159,7 @@ void coalesced_reduction(raft::device_resources const& handle,
 
 /** @} */  // end of group coalesced_reduction
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh
index 4321e13d95..3b1e8c41c4 100644
--- a/cpp/include/raft/linalg/contractions.cuh
+++ b/cpp/include/raft/linalg/contractions.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -100,7 +100,7 @@ struct KernelPolicy {
     SmemSize = 2 * SmemPage * sizeof(DataT),
   };  // enum
 
-}; // struct KernelPolicy
+};  // struct KernelPolicy
 
 template <typename DataT, int _veclen, int _kblk, int _rpt, int _cpt, int _tr, int _tc>
 struct ColKernelPolicy {
@@ -151,8 +151,7 @@ struct ColKernelPolicy {
  * @{
  */
 template <typename DataT, int _veclen>
-struct Policy4x4 {
-};
+struct Policy4x4 {};
 
 template <int _veclen>
 struct Policy4x4<float, _veclen> {
@@ -174,8 +173,7 @@ struct Policy4x4<double, _veclen> {
  *
 */
 template <typename DataT, int _veclen>
-struct Policy4x4Skinny {
-};
+struct Policy4x4Skinny {};
 
 template <int _veclen>
 struct Policy4x4Skinny<float, _veclen> {
@@ -194,8 +192,7 @@ struct Policy4x4Skinny<double, _veclen> {
  * @{
 */
 template <typename DataT, int _veclen>
-struct Policy2x8 {
-};
+struct Policy2x8 {};
 
 template <int _veclen>
 struct Policy2x8<float, _veclen> {
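For readers unfamiliar with the mechanism the new -ext header below relies on, here is a self-contained sketch (all names hypothetical, not RAFT API) of how an extern template declaration in a header pairs with a single explicit instantiation in a library translation unit:

// mylib.hpp -- template definition visible to every includer,
// just as coalescedReduction's definition is visible via the -inl header.
template <typename T>
T triple(T x)
{
  return x + x + x;
}

// Still in mylib.hpp: promise that the int instantiation already exists in the
// compiled library, so including translation units do not emit their own copy.
extern template int triple<int>(int);

// mylib.cpp -- compiled once into the library; this line emits the actual code.
template int triple<int>(int);

The macro in the header below stamps out exactly such extern declarations for the type and operator combinations that libraft pre-compiles.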
diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh
new file mode 100644
index 0000000000..4800f2e3cf
--- /dev/null
+++ b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <raft/core/operators.hpp>
+
+// The explicit instantiation of raft::linalg::detail::coalescedReduction is not
+// forced because there would be too many instances. Instead, we cover the most
+// common instantiations with extern template instantiations below.
+
+#define instantiate_raft_linalg_detail_coalescedReduction(                              \
+  InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda)                      \
+  extern template void raft::linalg::detail::coalescedReduction(OutType* dots,          \
+                                                                const InType* data,     \
+                                                                IdxType D,              \
+                                                                IdxType N,              \
+                                                                OutType init,           \
+                                                                cudaStream_t stream,    \
+                                                                bool inplace,           \
+                                                                MainLambda main_op,     \
+                                                                ReduceLambda reduce_op, \
+                                                                FinalLambda final_op)
+
+instantiate_raft_linalg_detail_coalescedReduction(
+  double, double, int, raft::identity_op, raft::min_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  double, double, int, raft::sq_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  double, double, int, raft::abs_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  double, double, int, raft::abs_op, raft::max_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, int, raft::abs_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, int, raft::identity_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, int, raft::identity_op, raft::min_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, int, raft::sq_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, long, raft::sq_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op);
+instantiate_raft_linalg_detail_coalescedReduction(
+  float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op);
+
+#undef instantiate_raft_linalg_detail_coalescedReduction
diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh
new file mode 100644
index 0000000000..5b01196cf4
--- /dev/null
+++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh
@@ -0,0 +1,368 @@
+/*
+ * Copyright (c)
2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { +namespace linalg { +namespace detail { + +template +struct ReductionThinPolicy { + static constexpr int LogicalWarpSize = warpSize; + static constexpr int RowsPerBlock = rpb; + static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThinKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool inplace = false) +{ + IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); + if (i >= N) return; + + OutType acc = init; + for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { + acc = reduce_op(acc, main_op(data[j + (D * i)], j)); + } + acc = raft::logicalWarpReduce(acc, reduce_op); + if (threadIdx.x == 0) { + if (inplace) { + dots[i] = final_op(reduce_op(dots[i], acc)); + } else { + dots[i] = final_op(acc); + } + } +} + +template +void coalescedReductionThin(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); + dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); + dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); + coalescedReductionThinKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void coalescedReductionThinDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + if (D <= IdxType(2)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(4)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(8)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else if (D <= IdxType(16)) { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } else { + coalescedReductionThin>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + } +} + +template +__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op, + FinalLambda final_op, + bool 
inplace = false) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + OutType thread_data = init; + IdxType rowStart = blockIdx.x * D; + for (IdxType i = threadIdx.x; i < D; i += TPB) { + IdxType idx = rowStart + i; + thread_data = reduce_op(thread_data, main_op(data[idx], i)); + } + OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); + if (threadIdx.x == 0) { + if (inplace) { + dots[blockIdx.x] = final_op(reduce_op(dots[blockIdx.x], acc)); + } else { + dots[blockIdx.x] = final_op(acc); + } + } +} + +template +void coalescedReductionMedium(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); + coalescedReductionMediumKernel + <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +template +void coalescedReductionMediumDispatcher(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + // Note: for now, this kernel is only used when D > 256. If this changes in the future, use + // smaller block sizes when relevant. + coalescedReductionMedium<256>( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); +} + +template +struct ReductionThickPolicy { + static constexpr int ThreadsPerBlock = tpb; + static constexpr int BlocksPerRow = bpr; + static constexpr int BlockStride = tpb * bpr; +}; + +template +__global__ void __launch_bounds__(Policy::ThreadsPerBlock) + coalescedReductionThickKernel(OutType* buffer, + const InType* data, + IdxType D, + IdxType N, + OutType init, + MainLambda main_op, + ReduceLambda reduce_op) +{ + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + OutType thread_data = init; + IdxType rowStart = blockIdx.x * D; + for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; + i += Policy::BlockStride) { + IdxType idx = rowStart + i; + thread_data = reduce_op(thread_data, main_op(data[idx], i)); + } + OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); + if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } +} + +template +void coalescedReductionThick(OutType* dots, + const InType* data, + IdxType D, + IdxType N, + OutType init, + cudaStream_t stream, + bool inplace = false, + MainLambda main_op = raft::identity_op(), + ReduceLambda reduce_op = raft::add_op(), + FinalLambda final_op = raft::identity_op()) +{ + common::nvtx::range fun_scope( + "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); + + dim3 threads(ThickPolicy::ThreadsPerBlock, 1, 1); + dim3 blocks(N, ThickPolicy::BlocksPerRow, 1); + + rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); + + /* We apply a two-step reduction: + * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It + * applies the main_op but not the final op. + * 2. coalescedReductionThinKernel reduces [N x BlocksPerRow] to [N x 1]. It doesn't apply any + * main_op but applies final_op. 
If in-place, the existing and new values are reduced.
+   */
+
+  coalescedReductionThickKernel<ThickPolicy>
+    <<<blocks, threads, 0, stream>>>(buffer.data(), data, D, N, init, main_op, reduce_op);
+  RAFT_CUDA_TRY(cudaPeekAtLastError());
+
+  coalescedReductionThin<ThinPolicy>(dots,
+                                     buffer.data(),
+                                     static_cast<IdxType>(ThickPolicy::BlocksPerRow),
+                                     N,
+                                     init,
+                                     stream,
+                                     inplace,
+                                     raft::identity_op(),
+                                     reduce_op,
+                                     final_op);
+}
+
+template <typename InType,
+          typename OutType,
+          typename IdxType,
+          typename MainLambda,
+          typename ReduceLambda,
+          typename FinalLambda>
+void coalescedReductionThickDispatcher(OutType* dots,
+                                       const InType* data,
+                                       IdxType D,
+                                       IdxType N,
+                                       OutType init,
+                                       cudaStream_t stream,
+                                       bool inplace = false,
+                                       MainLambda main_op = raft::identity_op(),
+                                       ReduceLambda reduce_op = raft::add_op(),
+                                       FinalLambda final_op = raft::identity_op())
+{
+  // Note: multiple elements per thread to take advantage of the sequential reduction and loop
+  // unrolling
+  if (D < IdxType(32768)) {
+    coalescedReductionThick<ReductionThickPolicy<256, 32>, ReductionThinPolicy<32, 4>>(
+      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
+  } else {
+    coalescedReductionThick<ReductionThickPolicy<256, 64>, ReductionThinPolicy<32, 4>>(
+      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
+  }
+}
+
+// Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along
+// rows for row major or reduce along columns for column major layout. Can do an inplace reduction
+// adding to original values of dots if requested.
+template <typename InType,
+          typename OutType      = InType,
+          typename IdxType      = int,
+          typename MainLambda   = raft::identity_op,
+          typename ReduceLambda = raft::add_op,
+          typename FinalLambda  = raft::identity_op>
+void coalescedReduction(OutType* dots,
+                        const InType* data,
+                        IdxType D,
+                        IdxType N,
+                        OutType init,
+                        cudaStream_t stream,
+                        bool inplace = false,
+                        MainLambda main_op = raft::identity_op(),
+                        ReduceLambda reduce_op = raft::add_op(),
+                        FinalLambda final_op = raft::identity_op())
+{
+  /* The primitive selects one of three implementations based on heuristics:
+   * - Thin: very efficient when D is small and/or N is large
+   * - Thick: used when N is very small and D very large
+   * - Medium: used when N is too small to fill the GPU with the thin kernel
+   */
+  const IdxType numSMs = raft::getMultiProcessorCount();
+  if (D <= IdxType(256) || N >= IdxType(4) * numSMs) {
+    coalescedReductionThinDispatcher(
+      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
+  } else if (N < numSMs && D >= IdxType(16384)) {
+    coalescedReductionThickDispatcher(
+      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
+  } else {
+    coalescedReductionMediumDispatcher(
+      dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op);
+  }
+}
+
+}  // namespace detail
+}  // namespace linalg
+}  // namespace raft
diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh
index 238e17fa56..3e6b17978b 100644
--- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh
+++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -16,353 +16,11 @@
 
 #pragma once
 
-#include <cub/cub.cuh>
-#include <raft/common/nvtx.hpp>
-#include <raft/core/operators.hpp>
-#include <raft/util/cuda_utils.cuh>
-#include <rmm/device_uvector.hpp>
+// Always include inline definitions of coalesced reduction, because we do not
+// force explicit instantiation.
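+// (Here the "-inl" suffix marks the header carrying the full template
+// definitions, while the "-ext" header pulled in further below carries only
+// extern template declarations whose definitions are compiled into libraft.)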
+#include "coalesced_reduction-inl.cuh" -namespace raft { -namespace linalg { -namespace detail { - -template -struct ReductionThinPolicy { - static constexpr int LogicalWarpSize = warpSize; - static constexpr int RowsPerBlock = rpb; - static constexpr int ThreadsPerBlock = LogicalWarpSize * RowsPerBlock; -}; - -template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) - coalescedReductionThinKernel(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) -{ - IdxType i = threadIdx.y + (Policy::RowsPerBlock * static_cast(blockIdx.x)); - if (i >= N) return; - - OutType acc = init; - for (IdxType j = threadIdx.x; j < D; j += Policy::LogicalWarpSize) { - acc = reduce_op(acc, main_op(data[j + (D * i)], j)); - } - acc = raft::logicalWarpReduce(acc, reduce_op); - if (threadIdx.x == 0) { - if (inplace) { - dots[i] = final_op(reduce_op(dots[i], acc)); - } else { - dots[i] = final_op(acc); - } - } -} - -template -void coalescedReductionThin(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope( - "coalescedReductionThin<%d,%d>", Policy::LogicalWarpSize, Policy::RowsPerBlock); - dim3 threads(Policy::LogicalWarpSize, Policy::RowsPerBlock, 1); - dim3 blocks(ceildiv(N, Policy::RowsPerBlock), 1, 1); - coalescedReductionThinKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void coalescedReductionThinDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - if (D <= IdxType(2)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(4)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(8)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (D <= IdxType(16)) { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionThin>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -template -__global__ void __launch_bounds__(TPB) coalescedReductionMediumKernel(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op, - FinalLambda final_op, - bool inplace = false) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - OutType thread_data = init; - IdxType rowStart = blockIdx.x * D; - for (IdxType i = threadIdx.x; i < D; i += TPB) { - IdxType idx = rowStart + i; - thread_data = reduce_op(thread_data, main_op(data[idx], i)); - } - OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); - if (threadIdx.x == 0) { - if (inplace) { - dots[blockIdx.x] = final_op(reduce_op(dots[blockIdx.x], acc)); - } else { - dots[blockIdx.x] = final_op(acc); - } - } -} - -template -void coalescedReductionMedium(OutType* dots, - 
const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope("coalescedReductionMedium<%d>", TPB); - coalescedReductionMediumKernel - <<>>(dots, data, D, N, init, main_op, reduce_op, final_op, inplace); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void coalescedReductionMediumDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - // Note: for now, this kernel is only used when D > 256. If this changes in the future, use - // smaller block sizes when relevant. - coalescedReductionMedium<256>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); -} - -template -struct ReductionThickPolicy { - static constexpr int ThreadsPerBlock = tpb; - static constexpr int BlocksPerRow = bpr; - static constexpr int BlockStride = tpb * bpr; -}; - -template -__global__ void __launch_bounds__(Policy::ThreadsPerBlock) - coalescedReductionThickKernel(OutType* buffer, - const InType* data, - IdxType D, - IdxType N, - OutType init, - MainLambda main_op, - ReduceLambda reduce_op) -{ - typedef cub::BlockReduce BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - OutType thread_data = init; - IdxType rowStart = blockIdx.x * D; - for (IdxType i = blockIdx.y * Policy::ThreadsPerBlock + threadIdx.x; i < D; - i += Policy::BlockStride) { - IdxType idx = rowStart + i; - thread_data = reduce_op(thread_data, main_op(data[idx], i)); - } - OutType acc = BlockReduce(temp_storage).Reduce(thread_data, reduce_op); - if (threadIdx.x == 0) { buffer[Policy::BlocksPerRow * blockIdx.x + blockIdx.y] = acc; } -} - -template -void coalescedReductionThick(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - common::nvtx::range fun_scope( - "coalescedReductionThick<%d,%d>", ThickPolicy::ThreadsPerBlock, ThickPolicy::BlocksPerRow); - - dim3 threads(ThickPolicy::ThreadsPerBlock, 1, 1); - dim3 blocks(N, ThickPolicy::BlocksPerRow, 1); - - rmm::device_uvector buffer(N * ThickPolicy::BlocksPerRow, stream); - - /* We apply a two-step reduction: - * 1. coalescedReductionThickKernel reduces the [N x D] input data to [N x BlocksPerRow]. It - * applies the main_op but not the final op. - * 2. coalescedReductionThinKernel reduces [N x BlocksPerRow] to [N x 1]. It doesn't apply any - * main_op but applies final_op. If in-place, the existing and new values are reduced. 
- */ - - coalescedReductionThickKernel - <<>>(buffer.data(), data, D, N, init, main_op, reduce_op); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - coalescedReductionThin(dots, - buffer.data(), - static_cast(ThickPolicy::BlocksPerRow), - N, - init, - stream, - inplace, - raft::identity_op(), - reduce_op, - final_op); -} - -template -void coalescedReductionThickDispatcher(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - // Note: multiple elements per thread to take advantage of the sequential reduction and loop - // unrolling - if (D < IdxType(32768)) { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionThick, ReductionThinPolicy<32, 4>>( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -// Primitive to perform reductions along the coalesced dimension of the matrix, i.e. reduce along -// rows for row major or reduce along columns for column major layout. Can do an inplace reduction -// adding to original values of dots if requested. -template -void coalescedReduction(OutType* dots, - const InType* data, - IdxType D, - IdxType N, - OutType init, - cudaStream_t stream, - bool inplace = false, - MainLambda main_op = raft::identity_op(), - ReduceLambda reduce_op = raft::add_op(), - FinalLambda final_op = raft::identity_op()) -{ - /* The primitive selects one of three implementations based on heuristics: - * - Thin: very efficient when D is small and/or N is large - * - Thick: used when N is very small and D very large - * - Medium: used when N is too small to fill the GPU with the thin kernel - */ - const IdxType numSMs = raft::getMultiProcessorCount(); - if (D <= IdxType(256) || N >= IdxType(4) * numSMs) { - coalescedReductionThinDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else if (N < numSMs && D >= IdxType(16384)) { - coalescedReductionThickDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } else { - coalescedReductionMediumDispatcher( - dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); - } -} - -} // namespace detail -} // namespace linalg -} // namespace raft \ No newline at end of file +// Do include the extern template instantiations when possible. 
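+// (RAFT_COMPILED is defined when building against, or inside, the pre-compiled
+// libraft binary, so the extern template declarations in the -ext header can
+// suppress redundant local instantiation of the common type combinations.)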
+#ifdef RAFT_COMPILED +#include "coalesced_reduction-ext.cuh" +#endif diff --git a/cpp/include/raft/linalg/detail/cublas_wrappers.hpp b/cpp/include/raft/linalg/detail/cublas_wrappers.hpp index 87a195757c..5a7356a4c2 100644 --- a/cpp/include/raft/linalg/detail/cublas_wrappers.hpp +++ b/cpp/include/raft/linalg/detail/cublas_wrappers.hpp @@ -41,9 +41,9 @@ class cublas_device_pointer_mode { } } auto operator=(const cublas_device_pointer_mode&) -> cublas_device_pointer_mode& = delete; - auto operator=(cublas_device_pointer_mode&&) -> cublas_device_pointer_mode& = delete; - static auto operator new(std::size_t) -> void* = delete; - static auto operator new[](std::size_t) -> void* = delete; + auto operator=(cublas_device_pointer_mode&&) -> cublas_device_pointer_mode& = delete; + static auto operator new(std::size_t) -> void* = delete; + static auto operator new[](std::size_t) -> void* = delete; ~cublas_device_pointer_mode() { @@ -550,7 +550,7 @@ cublasStatus_t cublasgetrfBatched(cublasHandle_t handle, template <> inline cublasStatus_t cublasgetrfBatched(cublasHandle_t handle, // NOLINT int n, - float* const A[], // NOLINT + float* const A[], // NOLINT int lda, int* P, int* info, @@ -564,7 +564,7 @@ inline cublasStatus_t cublasgetrfBatched(cublasHandle_t handle, // NOLINT template <> inline cublasStatus_t cublasgetrfBatched(cublasHandle_t handle, // NOLINT int n, - double* const A[], // NOLINT + double* const A[], // NOLINT int lda, int* P, int* info, diff --git a/cpp/include/raft/linalg/detail/map_then_reduce.cuh b/cpp/include/raft/linalg/detail/map_then_reduce.cuh index 70bb2df4f5..c22ef09809 100644 --- a/cpp/include/raft/linalg/detail/map_then_reduce.cuh +++ b/cpp/include/raft/linalg/detail/map_then_reduce.cuh @@ -25,8 +25,7 @@ namespace raft { namespace linalg { namespace detail { -struct sum_tag { -}; +struct sum_tag {}; template __device__ void reduce(OutType* out, const InType acc, sum_tag) diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 4cba028d87..bc7c551d89 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -18,6 +18,8 @@ #include "cublas_wrappers.hpp" #include "cusolver_wrappers.hpp" +#include +#include #include #include #include @@ -42,10 +44,10 @@ namespace detail { */ template void qrGetQ_inplace( - raft::device_resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) + raft::resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) { RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols."); - cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolver = resource::get_cusolver_dn_handle(handle); rmm::device_uvector tau(n_cols, stream); RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream)); @@ -83,7 +85,7 @@ void qrGetQ_inplace( } template -void qrGetQ(raft::device_resources const& handle, +void qrGetQ(raft::resources const& handle, const math_t* M, math_t* Q, int n_rows, @@ -95,7 +97,7 @@ void qrGetQ(raft::device_resources const& handle, } template -void qrGetQR(raft::device_resources const& handle, +void qrGetQR(raft::resources const& handle, math_t* M, math_t* Q, math_t* R, @@ -103,7 +105,7 @@ void qrGetQR(raft::device_resources const& handle, int n_cols, cudaStream_t stream) { - cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle(); + cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle); int m = n_rows, n = n_cols; rmm::device_uvector 
diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh
index 4cba028d87..bc7c551d89 100644
--- a/cpp/include/raft/linalg/detail/qr.cuh
+++ b/cpp/include/raft/linalg/detail/qr.cuh
@@ -18,6 +18,8 @@
 
 #include "cublas_wrappers.hpp"
 #include "cusolver_wrappers.hpp"
+#include <raft/core/resource/cusolver_dn_handle.hpp>
+#include <raft/core/resources.hpp>
 #include <raft/util/cuda_utils.cuh>
 #include <rmm/device_scalar.hpp>
 #include <rmm/device_uvector.hpp>
@@ -42,10 +44,10 @@ namespace detail {
  */
 template <typename math_t>
 void qrGetQ_inplace(
-  raft::device_resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream)
+  raft::resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream)
 {
   RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols.");
-  cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle();
+  cusolverDnHandle_t cusolver = resource::get_cusolver_dn_handle(handle);
 
   rmm::device_uvector<math_t> tau(n_cols, stream);
   RAFT_CUDA_TRY(cudaMemsetAsync(tau.data(), 0, sizeof(math_t) * n_cols, stream));
@@ -83,7 +85,7 @@ void qrGetQ_inplace(
 }
 
 template <typename math_t>
-void qrGetQ(raft::device_resources const& handle,
+void qrGetQ(raft::resources const& handle,
             const math_t* M,
             math_t* Q,
             int n_rows,
@@ -95,7 +97,7 @@ void qrGetQ(raft::device_resources const& handle,
 }
 
 template <typename math_t>
-void qrGetQR(raft::device_resources const& handle,
+void qrGetQR(raft::resources const& handle,
              math_t* M,
              math_t* Q,
              math_t* R,
@@ -103,7 +105,7 @@ void qrGetQR(raft::device_resources const& handle,
              int n_cols,
              cudaStream_t stream)
 {
-  cusolverDnHandle_t cusolverH = handle.get_cusolver_dn_handle();
+  cusolverDnHandle_t cusolverH = resource::get_cusolver_dn_handle(handle);
 
   int m = n_rows, n = n_cols;
   rmm::device_uvector<math_t> R_full(m * n, stream);
diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh
index 05588bda9c..bbd71a4cf1 100644
--- a/cpp/include/raft/linalg/detail/transpose.cuh
+++ b/cpp/include/raft/linalg/detail/transpose.cuh
@@ -19,7 +19,9 @@
 #include "cublas_wrappers.hpp"
 
 #include <raft/core/device_mdspan.hpp>
-#include <raft/core/device_resources.hpp>
+#include <raft/core/resource/cublas_handle.hpp>
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resources.hpp>
 #include <raft/util/cuda_utils.cuh>
 #include <raft/util/cudart_utils.hpp>
 #include <rmm/exec_policy.hpp>
@@ -29,14 +31,14 @@ namespace linalg {
 namespace detail {
 
 template <typename math_t>
-void transpose(raft::device_resources const& handle,
+void transpose(raft::resources const& handle,
                math_t* in,
                math_t* out,
                int n_rows,
                int n_cols,
                cudaStream_t stream)
 {
-  cublasHandle_t cublas_h = handle.get_cublas_handle();
+  cublasHandle_t cublas_h = resource::get_cublas_handle(handle);
   RAFT_CUBLAS_TRY(cublasSetStream(cublas_h, stream));
 
   int out_n_rows = n_cols;
@@ -83,7 +85,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
 
 template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
 void transpose_row_major_impl(
-  raft::device_resources const& handle,
+  raft::resources const& handle,
   raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
   raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
 {
@@ -92,7 +94,7 @@ void transpose_row_major_impl(
   T constexpr kOne  = 1;
   T constexpr kZero = 0;
 
-  CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
+  CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle),
                         CUBLAS_OP_T,
                         CUBLAS_OP_N,
                         out_n_cols,
@@ -105,12 +107,12 @@ void transpose_row_major_impl(
                         out.stride(0),
                         out.data_handle(),
                         out.stride(0),
-                        handle.get_stream()));
+                        resource::get_cuda_stream(handle)));
 }
 
 template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
 void transpose_col_major_impl(
-  raft::device_resources const& handle,
+  raft::resources const& handle,
   raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
   raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
 {
@@ -119,7 +121,7 @@ void transpose_col_major_impl(
   T constexpr kOne  = 1;
   T constexpr kZero = 0;
 
-  CUBLAS_TRY(cublasgeam(handle.get_cublas_handle(),
+  CUBLAS_TRY(cublasgeam(resource::get_cublas_handle(handle),
                         CUBLAS_OP_T,
                         CUBLAS_OP_N,
                         out_n_rows,
@@ -132,7 +134,7 @@ void transpose_col_major_impl(
                         out.stride(1),
                         out.data_handle(),
                         out.stride(1),
-                        handle.get_stream()));
+                        resource::get_cuda_stream(handle)));
 }
 };  // end namespace detail
 };  // end namespace linalg
diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh
index 0b18e6175c..428b9ba618 100644
--- a/cpp/include/raft/linalg/divide.cuh
+++ b/cpp/include/raft/linalg/divide.cuh
@@ -95,7 +95,7 @@ void divide_scalar(raft::device_resources const& handle,
 
 /** @} */  // end of group add
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/eig.cuh b/cpp/include/raft/linalg/eig.cuh
index 03e94a10b1..7829f8e49f 100644
--- a/cpp/include/raft/linalg/eig.cuh
+++ b/cpp/include/raft/linalg/eig.cuh
@@ -219,7 +219,7 @@ void eig_jacobi(raft::device_resources const& handle,
 
 /** @} */  // end of eig
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/gemv.cuh b/cpp/include/raft/linalg/gemv.cuh
index 96846003f6..019ec9f7ac 100644
--- a/cpp/include/raft/linalg/gemv.cuh
+++ b/cpp/include/raft/linalg/gemv.cuh
@@ -304,6 +304,6 @@ void gemv(raft::device_resources const& handle,
 }
 
 /** @} */  // end of gemv
 
-}; // namespace linalg
-}; // namespace raft
+};  // namespace linalg
+};  // namespace raft
 
 #endif
\ No newline at end of file
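The qr.cuh and transpose.cuh edits above typify the recurring change in the rest of this diff: raft::device_resources parameters widen to the opaque raft::resources base, and member accessors such as handle.get_cublas_handle() and handle.get_stream() become free functions under raft::resource. A before/after sketch with a toy helper (scale_inplace is our example, not a RAFT API):

    #include <cublas_v2.h>
    #include <raft/core/resource/cublas_handle.hpp>
    #include <raft/core/resource/cuda_stream.hpp>
    #include <raft/core/resources.hpp>

    // x *= alpha, executed on the handle's stream.
    inline void scale_inplace(raft::resources const& handle, float* x, int n, float alpha)
    {
      // Before: handle.get_cublas_handle() / handle.get_stream() on device_resources.
      cublasHandle_t cublas_h = raft::resource::get_cublas_handle(handle);
      cudaStream_t stream     = raft::resource::get_cuda_stream(handle);
      cublasSetStream(cublas_h, stream);
      cublasSscal(cublas_h, n, &alpha, x, 1);
    }

Taking the base class keeps the public headers decoupled from the full device_resources bundle: a caller only pays for the resources a function actually pulls out.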
diff --git a/cpp/include/raft/linalg/lanczos.cuh b/cpp/include/raft/linalg/lanczos.cuh
index c9f3e0010e..04e9980583 100644
--- a/cpp/include/raft/linalg/lanczos.cuh
+++ b/cpp/include/raft/linalg/lanczos.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,9 +24,9 @@
 
 #pragma once
 
-#pragma message(__FILE__ \
-                " is deprecated and will be removed in a future release." \
-                " Please use the sparse solvers version instead.")
+#pragma message(__FILE__                                                  \
+                " is deprecated and will be removed in a future release." \
+                " Please use the sparse solvers version instead.")
 
 #include <raft/sparse/solver/lanczos.cuh>
diff --git a/cpp/include/raft/linalg/lstsq.cuh b/cpp/include/raft/linalg/lstsq.cuh
index b36a9eba96..c753215737 100644
--- a/cpp/include/raft/linalg/lstsq.cuh
+++ b/cpp/include/raft/linalg/lstsq.cuh
@@ -244,7 +244,7 @@ void lstsq_qr(raft::device_resources const& handle,
 
 /** @} */  // end of lstsq
 
-}; // namespace linalg
-}; // namespace raft
+};  // namespace linalg
+};  // namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh
index 59b2ca5ee5..6c65626ac5 100644
--- a/cpp/include/raft/linalg/matrix_vector_op.cuh
+++ b/cpp/include/raft/linalg/matrix_vector_op.cuh
@@ -238,7 +238,7 @@ void matrix_vector_op(raft::device_resources const& handle,
 
 /** @} */  // end of group matrix_vector_op
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh
index 62f4896d01..317c085673 100644
--- a/cpp/include/raft/linalg/mean_squared_error.cuh
+++ b/cpp/include/raft/linalg/mean_squared_error.cuh
@@ -74,7 +74,7 @@ void mean_squared_error(raft::device_resources const& handle,
 
 /** @} */  // end of group mean_squared_error
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
diff --git a/cpp/include/raft/linalg/multiply.cuh b/cpp/include/raft/linalg/multiply.cuh
index 574b88c63d..bdca641616 100644
--- a/cpp/include/raft/linalg/multiply.cuh
+++ b/cpp/include/raft/linalg/multiply.cuh
@@ -97,7 +97,7 @@ void multiply_scalar(
 
 /** @} */  // end of group multiply
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh
index 1fdfcb3780..057d6f6827 100644
--- a/cpp/include/raft/linalg/power.cuh
+++ b/cpp/include/raft/linalg/power.cuh
@@ -153,7 +153,7 @@ void power_scalar(
 
 /** @} */  // end of group add
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
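Aside from the copyright bump, the lanczos.cuh hunk only re-aligns the banner's continuation backslashes. The mechanism itself is plain preprocessor: #pragma message prints its string in every translation unit that includes the header, which is how RAFT steers users toward the sparse-solver replacement. A minimal reproduction, with a hypothetical header name:

    // old_solver.hpp -- hypothetical deprecated forwarding header
    #pragma once

    #pragma message(__FILE__ \
                    " is deprecated and will be removed in a future release.")

    // The real header then forwards to its replacement, e.g.:
    // #include <raft/sparse/solver/lanczos.cuh>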
diff --git a/cpp/include/raft/linalg/qr.cuh b/cpp/include/raft/linalg/qr.cuh
index 8e58af63c1..948996d0ac 100644
--- a/cpp/include/raft/linalg/qr.cuh
+++ b/cpp/include/raft/linalg/qr.cuh
@@ -19,6 +19,8 @@
 #pragma once
 
 #include "detail/qr.cuh"
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resources.hpp>
 
 namespace raft {
 namespace linalg {
@@ -33,7 +35,7 @@ namespace linalg {
  * @param stream cuda stream
 */
 template <typename math_t>
-void qrGetQ(raft::device_resources const& handle,
+void qrGetQ(raft::resources const& handle,
             const math_t* M,
             math_t* Q,
             int n_rows,
@@ -54,7 +56,7 @@ void qrGetQ(raft::device_resources const& handle,
  * @param stream cuda stream
 */
 template <typename math_t>
-void qrGetQR(raft::device_resources const& handle,
+void qrGetQR(raft::resources const& handle,
              math_t* M,
              math_t* Q,
              math_t* R,
@@ -77,13 +79,18 @@ void qrGetQR(raft::device_resources const& handle,
  * @param[out] Q Output raft::device_matrix_view
 */
 template <typename ElementType, typename IndexType>
-void qr_get_q(raft::device_resources const& handle,
+void qr_get_q(raft::resources const& handle,
               raft::device_matrix_view<const ElementType, IndexType, raft::col_major> M,
               raft::device_matrix_view<ElementType, IndexType, raft::col_major> Q)
 {
   RAFT_EXPECTS(Q.size() == M.size(), "Size mismatch between Output and Input");
 
-  qrGetQ(handle, M.data_handle(), Q.data_handle(), M.extent(0), M.extent(1), handle.get_stream());
+  qrGetQ(handle,
+         M.data_handle(),
+         Q.data_handle(),
+         M.extent(0),
+         M.extent(1),
+         resource::get_cuda_stream(handle));
 }
 
 /**
@@ -94,7 +101,7 @@ void qr_get_q(raft::device_resources const& handle,
  * @param[out] R Output raft::device_matrix_view
 */
 template <typename ElementType, typename IndexType>
-void qr_get_qr(raft::device_resources const& handle,
+void qr_get_qr(raft::resources const& handle,
                raft::device_matrix_view<const ElementType, IndexType, raft::col_major> M,
                raft::device_matrix_view<ElementType, IndexType, raft::col_major> Q,
                raft::device_matrix_view<ElementType, IndexType, raft::col_major> R)
@@ -107,7 +114,7 @@ void qr_get_qr(raft::device_resources const& handle,
            R.data_handle(),
            M.extent(0),
            M.extent(1),
-           handle.get_stream());
+           resource::get_cuda_stream(handle));
 }
 
 /** @} */
diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh
index ae5457c44f..06f62f207e 100644
--- a/cpp/include/raft/linalg/reduce.cuh
+++ b/cpp/include/raft/linalg/reduce.cuh
@@ -161,7 +161,7 @@ void reduce(raft::device_resources const& handle,
 
 /** @} */  // end of group reduction
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/reduce_cols_by_key.cuh
index 2b744d8134..71c8cf14a1 100644
--- a/cpp/include/raft/linalg/reduce_cols_by_key.cuh
+++ b/cpp/include/raft/linalg/reduce_cols_by_key.cuh
@@ -112,7 +112,7 @@ void reduce_cols_by_key(
 
 /** @} */  // end of group reduce_cols_by_key
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh
index 484b60238b..0e83c9aa2b 100644
--- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh
+++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh
@@ -191,7 +191,7 @@ void reduce_rows_by_key(
 
 /** @} */  // end of group reduce_rows_by_key
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
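For reference, calling the mdspan overload qr_get_q updated above looks like the following sketch. We assume the signature shown in this diff (column-major views) plus the raft core factory helpers; a concrete device_resources is constructed because it binds to raft::resources const&:

    #include <raft/core/device_mdarray.hpp>
    #include <raft/core/device_resources.hpp>
    #include <raft/linalg/qr.cuh>

    void orthonormalize_example()
    {
      raft::device_resources handle;  // concrete bundle; converts to raft::resources const&
      auto M = raft::make_device_matrix<float, int, raft::col_major>(handle, 8, 4);
      auto Q = raft::make_device_matrix<float, int, raft::col_major>(handle, 8, 4);
      // ... fill M on the handle's stream ...
      raft::linalg::qr_get_q(handle, raft::make_const_mdspan(M.view()), Q.view());
    }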
diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh
index eb94547f13..8a32467873 100644
--- a/cpp/include/raft/linalg/rsvd.cuh
+++ b/cpp/include/raft/linalg/rsvd.cuh
@@ -765,7 +765,7 @@ void rsvd_perc_symmetric_jacobi(Args... args)
 
 /** @} */  // end of group rsvd
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh
index 55e661897d..eecc719617 100644
--- a/cpp/include/raft/linalg/sqrt.cuh
+++ b/cpp/include/raft/linalg/sqrt.cuh
@@ -83,7 +83,7 @@ void sqrt(raft::device_resources const& handle, InType in, OutType out)
 
 /** @} */  // end of group add
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh
index f58dfe28b3..25be368865 100644
--- a/cpp/include/raft/linalg/strided_reduction.cuh
+++ b/cpp/include/raft/linalg/strided_reduction.cuh
@@ -170,7 +170,7 @@ void strided_reduction(raft::device_resources const& handle,
 
 /** @} */  // end of group strided_reduction
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/subtract.cuh b/cpp/include/raft/linalg/subtract.cuh
index da995b7a2a..cbd6b9df59 100644
--- a/cpp/include/raft/linalg/subtract.cuh
+++ b/cpp/include/raft/linalg/subtract.cuh
@@ -222,7 +222,7 @@ void subtract_scalar(
 
 /** @} */  // end of group subtract
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh
index 4b78f2ef61..801d271fe9 100644
--- a/cpp/include/raft/linalg/svd.cuh
+++ b/cpp/include/raft/linalg/svd.cuh
@@ -415,7 +415,7 @@ void svd_reconstruction(raft::device_resources const& handle,
 
 /** @} */  // end of group svd
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/linalg/ternary_op.cuh b/cpp/include/raft/linalg/ternary_op.cuh
index 1e347d69be..ce95e98499 100644
--- a/cpp/include/raft/linalg/ternary_op.cuh
+++ b/cpp/include/raft/linalg/ternary_op.cuh
@@ -83,7 +83,7 @@ void ternary_op(
 
 /** @} */  // end of group ternary_op
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh
index a0f418b4f7..0fe752347d 100644
--- a/cpp/include/raft/linalg/transpose.cuh
+++ b/cpp/include/raft/linalg/transpose.cuh
@@ -20,6 +20,7 @@
 #include "detail/transpose.cuh"
 
 #include <raft/core/device_mdspan.hpp>
+#include <raft/core/resources.hpp>
 
 namespace raft {
 namespace linalg {
@@ -34,7 +35,7 @@ namespace linalg {
 * @param stream: cuda stream
 */
 template <typename math_t>
-void transpose(raft::device_resources const& handle,
+void transpose(raft::resources const& handle,
               math_t* in,
               math_t* out,
               int n_rows,
@@ -76,7 +77,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream)
 * @param[out] out Output matrix, storage is pre-allocated by caller.
 */
 template <typename T, typename IndexType, typename LayoutPolicy, typename AccessorPolicy>
-auto transpose(raft::device_resources const& handle,
+auto transpose(raft::resources const& handle,
               raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> in,
               raft::mdspan<T, raft::matrix_extent<IndexType>, LayoutPolicy, AccessorPolicy> out)
   -> std::enable_if_t<std::is_floating_point_v<T>, void>
@@ -102,7 +103,7 @@ auto transpose(raft::device_resources const& handle,
 
 /** @} */  // end of group transpose
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
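Likewise, a call sketch for the mdspan transpose overload above, assuming the default row-major layout from make_device_matrix (illustration only):

    #include <raft/core/device_mdarray.hpp>
    #include <raft/core/device_resources.hpp>
    #include <raft/linalg/transpose.cuh>

    void transpose_example()
    {
      raft::device_resources handle;
      auto in  = raft::make_device_matrix<float, int>(handle, 4, 3);  // row-major by default
      auto out = raft::make_device_matrix<float, int>(handle, 3, 4);  // receives the transpose
      raft::linalg::transpose(handle, in.view(), out.view());
    }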
diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh
index 23f932d2f2..58ff2f6bd6 100644
--- a/cpp/include/raft/linalg/unary_op.cuh
+++ b/cpp/include/raft/linalg/unary_op.cuh
@@ -124,7 +124,7 @@ void write_only_unary_op(const raft::device_resources& handle, OutType out, Lambda op)
 
 /** @} */  // end of group unary_op
 
-}; // end namespace linalg
-}; // end namespace raft
+};  // end namespace linalg
+};  // end namespace raft
 
 #endif
diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh
index a4daf097e5..6546a48279 100644
--- a/cpp/include/raft/matrix/col_wise_sort.cuh
+++ b/cpp/include/raft/matrix/col_wise_sort.cuh
@@ -133,6 +133,6 @@ void sort_cols_per_row(Args... args)
 
 /** @} */  // end of group col_wise_sort
 
-}; // end namespace raft::matrix
+};  // end namespace raft::matrix
 
 #endif
\ No newline at end of file
diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh
new file mode 100644
index 0000000000..2b233c156d
--- /dev/null
+++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh
@@ -0,0 +1,65 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cstdint>                            // uint32_t
+#include <cuda_fp16.h>                        // __half
+#include <raft/util/raft_explicit.hpp>        // RAFT_EXPLICIT
+#include <rmm/cuda_stream_view.hpp>           // rmm::cuda_stream_view
+#include <rmm/mr/device_memory_resource.hpp>  // rmm::mr::device_memory_resource
+
+#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY
+
+namespace raft::matrix::detail {
+
+template <typename T, typename IdxT>
+void select_k(const T* in_val,
+              const IdxT* in_idx,
+              size_t batch_size,
+              size_t len,
+              int k,
+              T* out_val,
+              IdxT* out_idx,
+              bool select_min,
+              rmm::cuda_stream_view stream,
+              rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT;
+}  // namespace raft::matrix::detail
+
+#endif  // RAFT_EXPLICIT_INSTANTIATE_ONLY
+
+#define instantiate_raft_matrix_detail_select_k(T, IdxT)                            \
+  extern template void raft::matrix::detail::select_k<T, IdxT>(                     \
+    const T* in_val,                                                                \
+    const IdxT* in_idx,                                                             \
+    size_t batch_size,                                                              \
+    size_t len,                                                                     \
+    int k,                                                                          \
+    T* out_val,                                                                     \
+    IdxT* out_idx,                                                                  \
+    bool select_min,                                                                \
+    rmm::cuda_stream_view stream,                                                   \
+    rmm::mr::device_memory_resource* mr)
+
+instantiate_raft_matrix_detail_select_k(__half, uint32_t);
+instantiate_raft_matrix_detail_select_k(__half, int64_t);
+instantiate_raft_matrix_detail_select_k(float, int64_t);
+instantiate_raft_matrix_detail_select_k(float, uint32_t);
+// We did not have these two for double before, but there are tests for them. We
+// therefore include them here.
+instantiate_raft_matrix_detail_select_k(double, int64_t);
+instantiate_raft_matrix_detail_select_k(double, uint32_t);
+
+#undef instantiate_raft_matrix_detail_select_k
diff --git a/cpp/include/raft/matrix/detail/select_k-inl.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh
new file mode 100644
index 0000000000..20c2fb119d
--- /dev/null
+++ b/cpp/include/raft/matrix/detail/select_k-inl.cuh
@@ -0,0 +1,91 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "select_radix.cuh"
+#include "select_warpsort.cuh"
+
+#include <raft/core/nvtx.hpp>
+
+#include <rmm/cuda_stream_view.hpp>
+#include <rmm/mr/device_memory_resource.hpp>
+
+namespace raft::matrix::detail {
+
+/**
+ * Select k smallest or largest key/values from each row in the input data.
+ *
+ * If you think of the input data `in_val` as a row-major matrix with `len` columns and
+ * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills
+ * in the row-major matrix `out_val` of size (batch_size, k).
+ *
+ * @tparam T
+ *   the type of the keys (what is being compared).
+ * @tparam IdxT
+ *   the index type (what is being selected together with the keys).
+ *
+ * @param[in] in_val
+ *   contiguous device array of inputs of size (len * batch_size);
+ *   these are compared and selected.
+ * @param[in] in_idx
+ *   contiguous device array of inputs of size (len * batch_size);
+ *   typically, these are indices of the corresponding in_val.
+ * @param batch_size
+ *   number of input rows, i.e. the batch size.
+ * @param len
+ *   length of a single input array (row); also sometimes referred to as n_cols.
+ *   Invariant: len >= k.
+ * @param k
+ *   the number of outputs to select in each input row.
+ * @param[out] out_val
+ *   contiguous device array of outputs of size (k * batch_size);
+ *   the k smallest/largest values from each row of the `in_val`.
+ * @param[out] out_idx
+ *   contiguous device array of outputs of size (k * batch_size);
+ *   the payload selected together with `out_val`.
+ * @param select_min
+ *   whether to select k smallest (true) or largest (false) keys.
+ * @param stream
+ * @param mr an optional memory resource to use across the calls (you can provide a large enough
+ *   memory pool here to avoid memory allocations within the call).
+ */
+template <typename T, typename IdxT>
+void select_k(const T* in_val,
+              const IdxT* in_idx,
+              size_t batch_size,
+              size_t len,
+              int k,
+              T* out_val,
+              IdxT* out_idx,
+              bool select_min,
+              rmm::cuda_stream_view stream,
+              rmm::mr::device_memory_resource* mr = nullptr)
+{
+  common::nvtx::range<common::nvtx::domain::raft> fun_scope(
+    "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
+  // TODO (achirkin): investigate the trade-off for a wider variety of inputs.
+  const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128;
+  if (k <= select::warpsort::kMaxCapacity && !radix_faster) {
+    select::warpsort::select_k<T, IdxT>(
+      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
+  } else {
+    select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
+      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);
+  }
+}
+
+}  // namespace raft::matrix::detail
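To make the documented contract concrete, a call sketch against the detail API above: select the 5 smallest values per row of a row-major batch. Treating a null in_idx as "return positions within each row" is our reading of the warpsort/radix backends, so consider that part an assumption; topk_example is our name:

    #include <cstdint>
    #include <raft/matrix/detail/select_k.cuh>
    #include <rmm/cuda_stream_view.hpp>
    #include <rmm/device_uvector.hpp>

    void topk_example(rmm::cuda_stream_view stream)
    {
      const size_t batch_size = 16, len = 1024;
      const int k = 5;
      rmm::device_uvector<float> in_val(batch_size * len, stream);  // filled elsewhere
      rmm::device_uvector<float> out_val(batch_size * k, stream);
      rmm::device_uvector<int64_t> out_idx(batch_size * k, stream);
      raft::matrix::detail::select_k<float, int64_t>(in_val.data(),
                                                     nullptr,  // assumption: row-local indices
                                                     batch_size,
                                                     len,
                                                     k,
                                                     out_val.data(),
                                                     out_idx.data(),
                                                     /*select_min=*/true,
                                                     stream);
    }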
diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh
index 20c2fb119d..711169984b 100644
--- a/cpp/include/raft/matrix/detail/select_k.cuh
+++ b/cpp/include/raft/matrix/detail/select_k.cuh
@@ -13,79 +13,12 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-
 #pragma once
 
-#include "select_radix.cuh"
-#include "select_warpsort.cuh"
-
-#include <raft/core/nvtx.hpp>
-
-#include <rmm/cuda_stream_view.hpp>
-#include <rmm/mr/device_memory_resource.hpp>
-
-namespace raft::matrix::detail {
-
-/**
- * Select k smallest or largest key/values from each row in the input data.
- *
- * If you think of the input data `in_val` as a row-major matrix with `len` columns and
- * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills
- * in the row-major matrix `out_val` of size (batch_size, k).
- *
- * @tparam T
- *   the type of the keys (what is being compared).
- * @tparam IdxT
- *   the index type (what is being selected together with the keys).
- *
- * @param[in] in_val
- *   contiguous device array of inputs of size (len * batch_size);
- *   these are compared and selected.
- * @param[in] in_idx
- *   contiguous device array of inputs of size (len * batch_size);
- *   typically, these are indices of the corresponding in_val.
- * @param batch_size
- *   number of input rows, i.e. the batch size.
- * @param len
- *   length of a single input array (row); also sometimes referred as n_cols.
- *   Invariant: len >= k.
- * @param k
- *   the number of outputs to select in each input row.
- * @param[out] out_val
- *   contiguous device array of outputs of size (k * batch_size);
- *   the k smallest/largest values from each row of the `in_val`.
- * @param[out] out_idx
- *   contiguous device array of outputs of size (k * batch_size);
- *   the payload selected together with `out_val`.
- * @param select_min
- *   whether to select k smallest (true) or largest (false) keys.
- * @param stream
- * @param mr an optional memory resource to use across the calls (you can provide a large enough
- *   memory pool here to avoid memory allocations within the call).
- */
-template <typename T, typename IdxT>
-void select_k(const T* in_val,
-              const IdxT* in_idx,
-              size_t batch_size,
-              size_t len,
-              int k,
-              T* out_val,
-              IdxT* out_idx,
-              bool select_min,
-              rmm::cuda_stream_view stream,
-              rmm::mr::device_memory_resource* mr = nullptr)
-{
-  common::nvtx::range<common::nvtx::domain::raft> fun_scope(
-    "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k);
-  // TODO (achirkin): investigate the trade-off for a wider variety of inputs.
-  const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128;
-  if (k <= select::warpsort::kMaxCapacity && !radix_faster) {
-    select::warpsort::select_k<T, IdxT>(
-      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr);
-  } else {
-    select::radix::select_k<T, IdxT, (sizeof(T) >= 4 ? 11 : 8), 512>(
-      in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr);
-  }
-}
+#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
+#include "select_k-inl.cuh"
+#endif
 
-}  // namespace raft::matrix::detail
+#ifdef RAFT_COMPILED
+#include "select_k-ext.cuh"
+#endif
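The include choreography that replaces the inline implementation is the general pattern this PR rolls out: each primitive splits into a *-inl.cuh holding the definitions and a *-ext.cuh holding extern template declarations, with a thin dispatch header choosing between them. Sketched for a hypothetical foo.cuh:

    // foo.cuh -- dispatch header (hypothetical)
    #pragma once

    #ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
    #include "foo-inl.cuh"  // header-only builds: pull in the full definitions
    #endif

    #ifdef RAFT_COMPILED
    #include "foo-ext.cuh"  // linking libraft: extern template declarations
    #endif

Header-only users compile everything inline; users linking the precompiled library see the extern declarations and skip redundant instantiation; builds of libraft itself define RAFT_EXPLICIT_INSTANTIATE_ONLY to enforce the split.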
diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh
index 7ac40ac0eb..b7d02d6b52 100644
--- a/cpp/include/raft/matrix/detail/select_radix.cuh
+++ b/cpp/include/raft/matrix/detail/select_radix.cuh
@@ -778,7 +778,7 @@ void radix_topk(const T* in,
   auto pool_guard = raft::get_pool_memory_resource(mr, mem_req);
   if (pool_guard) {
     RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes",
-                   pool_guard->pool_size());
+                   mem_req);
   }
 
   rmm::device_uvector<Counter<T, IdxT>> counters(max_chunk_size, stream, mr);
@@ -1031,10 +1031,7 @@ void radix_topk_one_block(const T* in,
     max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) + 256 * 4  // might need extra memory for alignment
   );
-  if (pool_guard) {
-    RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes",
-                   pool_guard->pool_size());
-  }
+  if (pool_guard) { RAFT_LOG_DEBUG("radix::select_k: using pool memory resource"); }
 
   rmm::device_uvector<T> buf1(len * max_chunk_size, stream, mr);
   rmm::device_uvector<IdxT> idx_buf1(len * max_chunk_size, stream, mr);
diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh
index d362b73792..dc86a04733 100644
--- a/cpp/include/raft/matrix/detail/select_warpsort.cuh
+++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh
@@ -27,7 +27,7 @@
 #include <raft/util/integer_utils.hpp>
 #include <raft/util/pow2_utils.cuh>
 
-#include <raft/core/device_resources.hpp>
+#include <raft/core/resources.hpp>
 #include <rmm/device_uvector.hpp>
 
 /*
@@ -870,8 +870,7 @@ struct launch_setup {
 };
 
 template