From 1b0b5eac540b7f8fd19b18f1e6b8427c95503348 Mon Sep 17 00:00:00 2001
From: Anerudhan Gopal
Date: Wed, 10 Apr 2024 17:48:57 +0000
Subject: [PATCH] cudnn frontend v1.3 release notes. (#72)

[New API] Added new operations `sdpa_fp8_forward` and `sdpa_fp8_backward` to perform scaled dot product attention of fp8 tensors. See more details in `docs/operations/Attention.md` and the cpp sample in `samples/cpp/mha.cpp`. Pybinds for the fp8 nodes are also added.

[New API] Added a new operation for resample forward. Added a new sample `samples/cpp/resample.cpp` to show its usage.

[New API] Added a new API `deselect_engines(std::vector const &engine_names)` which blocks certain engine configs from running.

[New API] Added new APIs `select_numeric_notes` and `select_behavior_notes` to allow users to select engine configs which have the requested numeric and behavior notes respectively.

[Python API] Added a custom exception `cudnnGraphNotSupportedException` to the python API to distinguish graphs that are genuinely not supported from programming errors.

[Python API] Added a new `backend_version_string` which returns the backend version in canonical form (e.g. 9.1.0) instead of a version number.

[Bug Fix] Updated the workspace computation for the sdpa fprop node. Previously, workspace was calculated for alibi slopes irrespective of whether the alibi mask was turned on or not.

[Bug Fix] Fixed deserialization of half-precision pass-by-value data.
---
 CMakeLists.txt | 6 +-
 README.FE.1.0.md | 38 +-
 docs/operations/Attention.md | 351 +++++++++++--
 docs/operations/Resampling.md | 50 ++
 include/cudnn_frontend.h | 4 +-
 include/cudnn_frontend/graph_helpers.h | 10 +
 include/cudnn_frontend/graph_interface.h | 174 ++++++-
 include/cudnn_frontend/graph_properties.h | 363 +++++++++----
 include/cudnn_frontend/node/matmul.h | 20 +-
 include/cudnn_frontend/node/matmul_fp8.h | 131 +++++
 include/cudnn_frontend/node/pointwise.h | 33 +-
 include/cudnn_frontend/node/reduction.h | 16 +-
 include/cudnn_frontend/node/resample.h | 198 +++++++
 include/cudnn_frontend/node/reshape.h | 2 +-
 include/cudnn_frontend/node/rng.h | 20 +-
 .../node/scaled_dot_product_flash_attention.h | 29 +-
 include/cudnn_frontend/node/sdpa_fp8.h | 263 ++++++++++
 include/cudnn_frontend/node/sdpa_fp8_bwd.h | 345 +++++++++++++
 include/cudnn_frontend/node/softmax.h | 41 +-
 include/cudnn_frontend/node_interface.h | 66 ++-
 include/cudnn_frontend/plans.h | 111 ++--
 .../thirdparty/nlohmann/json.hpp | 417 +++++++++------
 include/cudnn_frontend/utils/serialize.h | 54 +-
 include/cudnn_frontend_Resample.h | 1 +
 include/cudnn_frontend_shim.h | 27 +
 include/cudnn_frontend_utils.h | 67 ++-
 pyproject.toml | 2 +-
 python/cudnn/__init__.py | 4 +-
 python/properties.cpp | 5 +-
 python/pycudnn.cpp | 9 +-
 python/pygraph/pygraph.cpp | 2 +
 python/pygraph/pygraph.h | 66 ++-
 python/pygraph/sdpa.cpp | 255 ++++++++-
 requirements.txt | 3 +-
 samples/CMakeLists.txt | 3 +
 samples/cpp/matmuls.cpp | 4 +-
 samples/cpp/mha.cpp | 488 ++++++++++++++++++
 samples/cpp/resample.cpp | 187 +++++++
 samples/utils/error_util.h | 1 +
 test/python_fe/test_batchnorm.py | 7 +-
 test/python_fe/test_conv_bias.py | 3 +-
 test/python_fe/test_conv_genstats.py | 3 +-
 test/python_fe/test_instancenorm.py | 3 +-
 test/python_fe/test_layernorm.py | 3 +-
 test/python_fe/test_matmul_bias_relu.py | 11 +-
 test/python_fe/test_mhas.py | 63 +--
 test/python_fe/test_rmsnorm.py | 3 +-
 test/python_fe/test_wgrads.py | 66 +--
 test/unit_tests/CMakeLists.txt | 4 +-
 test/unit_tests/version.cpp | 33 ++
 50 files changed, 3466 insertions(+),
599 deletions(-) create mode 100644 docs/operations/Resampling.md create mode 100644 include/cudnn_frontend/node/matmul_fp8.h create mode 100644 include/cudnn_frontend/node/resample.h create mode 100644 include/cudnn_frontend/node/sdpa_fp8.h create mode 100644 include/cudnn_frontend/node/sdpa_fp8_bwd.h create mode 100644 samples/cpp/resample.cpp create mode 100644 test/unit_tests/version.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index b292979b..31e10f91 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.17) -project(cudnn_frontend VERSION 1.2.1) +project(cudnn_frontend VERSION 1.3.0) option(CUDNN_FRONTEND_SKIP_NLOHMANN_JSON "Defines whether FE should not include nlohmann/json.hpp." OFF) option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON) @@ -39,6 +39,10 @@ target_link_libraries( target_compile_features(cudnn_frontend INTERFACE cxx_std_17) +# Make PCH for targets to link against +add_library(_cudnn_frontend_pch INTERFACE) +target_precompile_headers(_cudnn_frontend_pch INTERFACE ${PROJECT_SOURCE_DIR}/include/cudnn_frontend.h) + if (CUDNN_FRONTEND_BUILD_SAMPLES) add_subdirectory(samples) endif() diff --git a/README.FE.1.0.md b/README.FE.1.0.md index b337980b..e23e491e 100644 --- a/README.FE.1.0.md +++ b/README.FE.1.0.md @@ -29,21 +29,24 @@ The steps involved in building and running a cudnn graph are as follows: ## APIs FE v1.0 API follows a functional style of building a graph. Operations take in input tensors and return output tensors. This also allows composition of operations. -| Purpose | C++ API | Python API | -| --- | --- | --- | -| Create tensor | tensor | tensor | -| [Convolution Fprop](docs/operations/Convolutions.md) | conv_fprop
Conv_fprop_attributes | conv_fprop | -| [Convolution Dgrad](docs/operations/Convolutions.md) | conv_dgrad
Conv_dgrad_attributes | conv_dgrad | -| [Convolution Wgrad](docs/operations/Convolutions.md) | conv_wgrad
Conv_wgrad_attributes | conv_wgrad | -| [Matrix Multiplication](docs/operations/Matmul.md) | matmul
Matmul_attributes | matmul | -| [Pointwise Operations](docs/operations/Pointwise.md) | pointwise
Pointwise_attributes | - add
- bias
- rqsrt
- sub
- mul
- scale
- relu
- elu
- gelu
- cmp_gt | -| [Batch Normalization](docs/operations/Normalizations.md) | batchnorm
Batchnorm_attributes | batchnorm | -| [Batch Norm bprop](docs/operations/Normalizations.md) | batchnorm_backward
Batchnorm_backward_attributes | batchnorm_backward | -| Generate stats of output| genstats
Genstats_attributes | genstats | -| BN Finalize of stats | bn_finalize
BN_finalize_attributes | bn_finalize | -| Dbn weight | dbn_weight
DBN_weight_attributes | dbn_weight | -| [Scale dot product attention](docs/operations/Attention.md) | sdpa
SDPA_attributes | sdpa | -| [Scale dot product attention backward](docs/operations/Attention.md) | sdpa_backward
SDPA_backward_attributes | sdpa_backward | +| Purpose | C++ API | Python API | +|--------------------------------------------------------------------------|------------------------------------------------------|--------------------------------------------------------------------------------------------------| +| Create tensor | tensor | tensor | +| [Convolution Fprop](docs/operations/Convolutions.md) | conv_fprop
Conv_fprop_attributes | conv_fprop | +| [Convolution Dgrad](docs/operations/Convolutions.md) | conv_dgrad
Conv_dgrad_attributes | conv_dgrad | +| [Convolution Wgrad](docs/operations/Convolutions.md) | conv_wgrad
Conv_wgrad_attributes | conv_wgrad | +| [Matrix Multiplication](docs/operations/Matmul.md) | matmul
Matmul_attributes | matmul | +| [Pointwise Operations](docs/operations/Pointwise.md) | pointwise
Pointwise_attributes | - add
- bias
- rqsrt
- sub
- mul
- scale
- relu
- elu
- gelu
- cmp_gt | +| [Batch Normalization](docs/operations/Normalizations.md) | batchnorm
Batchnorm_attributes | batchnorm | +| [Batch Norm bprop](docs/operations/Normalizations.md) | batchnorm_backward
Batchnorm_backward_attributes | batchnorm_backward | +| Generate stats of output | genstats
Genstats_attributes | genstats | +| BN Finalize of stats | bn_finalize
BN_finalize_attributes | bn_finalize | +| Dbn weight | dbn_weight
DBN_weight_attributes | dbn_weight | +| [Resampling](docs/operations/Resampling.md) | resample
Resample_attributes | resample | +| [Scale dot product attention](docs/operations/Attention.md) | sdpa
SDPA_attributes | sdpa | +| [Scale dot product attention backward](docs/operations/Attention.md) | sdpa_backward
SDPA_backward_attributes | sdpa_backward | +| [Scale dot product attention FP8](docs/operations/Attention.md) | sdpa_fp8
SDPA_fp8_attributes | sdpa_fp8 | +| [Scale dot product attention backward FP8](docs/operations/Attention.md) | sdpa_fp8_backward
SDPA_fp8_backward_attributes | sdpa_fp8_backward | ### Create Graph Instantiate an object of class `cudnn_frontend::graph::Graph` which will house tensors and operations. @@ -141,9 +144,12 @@ cudnn_frontend::Graph::build_plan_at_index( ### Filter plans (optional) -Users can filter out plans against numerical, behavioral notes, or plans that do not provide desired functional correctness. +Users can filter plans on numerical, behavioral notes, or plans that do not provide desired functional correctness. ``` +cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::select_numeric_notes(std::vector const&); +cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::select_behavior_notes(std::vector const&); + cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_numeric_notes(std::vector const&); cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_behavior_notes(std::vector const&); cudnn_frontend::graph::Graph& cudnn_frontend::graph::Plans::deselect_workspace_greater_than(int64_t const workspace); diff --git a/docs/operations/Attention.md b/docs/operations/Attention.md index e7e2a0a5..a40b993e 100644 --- a/docs/operations/Attention.md +++ b/docs/operations/Attention.md @@ -1,11 +1,13 @@ ## Table of Contents 1. [Scaled Dot Product Attention](#scaled-dot-product-attention) 2. [Scaled Dot Product Attention Backward](#scaled-dot-product-attention-backward) -3. Appendices +3. [Scaled Dot Product Attention FP8](#scaled-dot-product-attention-fp8) +4. [Scaled Dot Product Attention Backward FP8](#scaled-dot-product-attention-backward-fp8) +5. Appendices - [Tensor Layouts](#appendix-a) - [Workspace limits and Performance](#appendix-b) - [RNG dump](#appendix-c) -4. [Miscellaneous](#miscellaneous) +6. [Miscellaneous](#miscellaneous) ### Scaled Dot Product Attention @@ -15,6 +17,12 @@ $\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$ using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). It is applicable for both training and inference phases, with an option to generate a stats tensor to be used for backwards training computation. +- Python sample: [samples/python/50_scaled_dot_product_attention.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/50_scaled_dot_product_attention.ipynb) + +- C++ sample: [samples/cpp/mha.cpp](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/cpp/mha.cpp) + +- Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py) + #### Configurable Options: - Attention scale (`attn_scale`): Applies a scaling factor to attention scores before the softmax, such as $\frac{1}{\sqrt{\text{d}}}$. Set to 1.0 by default. @@ -33,8 +41,6 @@ using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2 - `dropout scale` used to adjust the scale of the remaining weights accordingly, such as $1 / (1 - \text{dropout probability})$. - Ragged tensor: allows the query, key, value, and output tensor to be [ragged tensors](https://www.tensorflow.org/guide/ragged_tensor), which are tensors with nested variable length lists as inner dimensions. Users must pass another tensor called ragged offset tensor using the `Tensor_attributes.set_ragged_offset()` method as specified in the tensors section below. -When multiple masking options are enabled, they are applied in the listed order above. 
- #### Tensors: - Query tensor should have dimensions $(B, H_{q}, S_{q}, D_{qk})$ with input/output datatype. @@ -74,9 +80,10 @@ Where, - the stride of the embedding dimension per head $D_{qk}$ and $D_{v}$ for all the tensors above must be 1. - this operation is only supported on GPUs with NVIDIA Ampere architecture (SM80) or newer. -**API:** +#### C++ API: ```cpp +// returns [output, softmax_stats] std::array, 2> sdpa(std::shared_ptr q, std::shared_ptr k, @@ -84,21 +91,19 @@ sdpa(std::shared_ptr q, SDPA_attributes options); ``` -The function returns an array of two tensors: `[output, softmax_stats]`. - The `options` parameter of type `SDPA_attributes` is used to control the attributes of the forward operation, as detailed below: ```cpp -SDPA_attributes & +SDPA_attributes& set_is_inference(bool const value); -SDPA_attributes & +SDPA_attributes& set_attn_scale(std::shared_ptr value); SDPA_attributes& set_attn_scale(float const value); -SDPA_attributes & +SDPA_attributes& set_bias(std::shared_ptr value); SDPA_attributes& @@ -113,23 +118,23 @@ set_seq_len_q(std::shared_ptr value); SDPA_attributes& set_seq_len_kv(std::shared_ptr value); -SDPA_attributes & +SDPA_attributes& set_causal_mask(bool const value); -SDPA_attributes & +SDPA_attributes& set_dropout(float const probability, std::shared_ptr seed, std::shared_ptr offset); -SDPA_attributes & +SDPA_attributes& set_dropout(std::shared_ptr mask, std::shared_ptr scale); -SDPA_attributes & +SDPA_attributes& set_compute_data_type(DataType_t value); ``` -**Python API:** +#### Python API: ``` Args: @@ -139,17 +144,17 @@ Args: is_inference (bool): Whether it is an inference step or training step. attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. + use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False. seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. - use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. Returns: - o (cudnn_tensor): The result of scaled dot-product attention. + o (cudnn_tensor): The output data. stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. ``` @@ -157,6 +162,12 @@ Returns: This operation computes gradient tensors for scaled dot product attention using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). The user is required to pass the stats tensor from the forward operation to the backward operation as input. 
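To make the stats handoff concrete, here is a minimal, illustrative C++ sketch (not part of the patch) of a training-time graph that produces `Stats` with `sdpa` and feeds it into `sdpa_backward`, following the signatures documented in this file. Tensor shapes, data pointers, and the `validate`/`build`/`execute` steps are omitted, and all names are placeholders:

```cpp
// Illustrative sketch only: shows the Stats tensor flowing from sdpa() to sdpa_backward().
#include <cmath>
#include <string>
#include <vector>
#include <cudnn_frontend.h>
namespace fe = cudnn_frontend;

void build_sdpa_training_graph() {
    int64_t const b = 2, h = 8, s_q = 512, s_kv = 512, d = 64;

    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::HALF)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);

    auto make = [&](std::string const& name, std::vector<int64_t> dim, std::vector<int64_t> stride) {
        return graph.tensor(fe::graph::Tensor_attributes().set_name(name).set_dim(dim).set_stride(stride));
    };
    // BSHD-style strides: embedding dimension has stride 1, as required above.
    auto Q  = make("Q",  {b, h, s_q, d},  {h * s_q * d,  d, h * d, 1});
    auto K  = make("K",  {b, h, s_kv, d}, {h * s_kv * d, d, h * d, 1});
    auto V  = make("V",  {b, h, s_kv, d}, {h * s_kv * d, d, h * d, 1});
    auto dO = make("dO", {b, h, s_q, d},  {h * s_q * d,  d, h * d, 1});

    float const attn_scale = 1.0f / std::sqrt(static_cast<float>(d));

    // Forward: is_inference = false, so the softmax stats tensor is produced.
    auto [O, Stats] = graph.sdpa(Q, K, V,
                                 fe::graph::SDPA_attributes()
                                     .set_name("sdpa")
                                     .set_is_inference(false)
                                     .set_causal_mask(true)
                                     .set_attn_scale(attn_scale));
    O->set_output(true);
    Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);

    // Backward: the same Stats tensor is consumed here as an input.
    auto [dQ, dK, dV] = graph.sdpa_backward(Q, K, V, O, dO, Stats,
                                            fe::graph::SDPA_backward_attributes()
                                                .set_name("sdpa_backward")
                                                .set_causal_mask(true)
                                                .set_attn_scale(attn_scale));
    dQ->set_output(true);
    dK->set_output(true);
    dV->set_output(true);
}
```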
+- Python sample: [samples/python/51_scaled_dot_product_attention_backward.ipynb](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/python/51_scaled_dot_product_attention_backward.ipynb) + +- C++ sample: [samples/cpp/mha.cpp](https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/cpp/mha.cpp) + +- Python tests: [test/python_fe/test_mhas.py](https://github.com/NVIDIA/cudnn-frontend/blob/main/test/python_fe/test_mhas.py) + #### Configurable Options: All the options mentioned in the forward operation, including ragged tensors and GQA/MQA, are applicable in the backward operation as well. @@ -169,8 +180,9 @@ All the tensor requirements described in the forward operation are applicable in All the limitations mentioned in the forward operation are applicable in the backward operation as well. -#### API: +#### C++ API: ```cpp +// returns [dQ, dK, dV] std::array, 3> sdpa_backward(std::shared_ptr q, std::shared_ptr k, @@ -181,8 +193,6 @@ sdpa_backward(std::shared_ptr q, SDPA_backward_attributes); ``` -The function returns an array of three tensors: `[dQ, dK, dV]`. - The `options` parameter of type `SDPA_backward_attributes` is used to control the attributes of backward operation, as detailed below: ```cpp @@ -227,7 +237,7 @@ SDPA_backward_attributes& set_compute_data_type(DataType_t const value); ``` -Python API: +#### Python API: ``` Args: @@ -241,21 +251,302 @@ Args: bias (Optional[cudnn_tensor]): The bias data for attention. Default is None. dBias (Optional[cudnn_tensor]): The dBias output for attention. Default is None. use_alibi_mask (Optional[bool]): Whether to use alibi mask. Default is False. + use_padding_mask (Optional[bool]): Whether to use padding mask. Default is False. + seq_len_q (Optional[cudnn_tensor]): The sequence length of the query. + seq_len_kv (Optional[cudnn_tensor]): The sequence length of the key. use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. - dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], - Tuple[mask: cudnn_tensor, scale: cudnn_tensor, scale_inv: cudnn_tensor]]]): - Whether to do dropout. Default is None. + dropout (Optional[Union[Tuple[(probability: float, seed: cudnn_tensor, offset: cudnn_tensor)], Tuple[mask: cudnn_tensor, scale: cudnn_tensor]]]): Whether to do dropout. Default is None. compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. name (Optional[str]): The name of the operation. Returns: - dQ (cudnn_tensor): The query gradient tensor of scaled dot-product attention. - dK (cudnn_tensor): The key gradient tensor of scaled dot-product attention. - dV (cudnn_tensor): The value gradient tensor of scaled dot-product attention. + dQ (cudnn_tensor): The query gradient data. + dK (cudnn_tensor): The key gradient data. + dV (cudnn_tensor): The value gradient data. +``` + +### Scaled Dot Product Attention FP8 + +This operation computes the scaled dot product attention in the FP8 (8-bit floating point) datatype, using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). It is applicable for both training and inference phases, with an option to generate a stats tensor to be used for backwards training computation. + +The FP8 datatype consists of two encodings: +- `FP8_E4M3` (1 sign bit, 4 exponent bits, and 3 mantissa bits) +- `FP8_E5M2` (1 sign bit, 5 exponent bits, 2 mantissa bits). 
+
+Due to the limited numerical precision of the FP8 data type, for practical use cases, users must scale values computed in FP32 format before storing them in FP8 format, and descale the values stored in FP8 format before performing computations on them. For more information, refer to [the Transformer Engine FP8 Primer](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html).
+
+The suggested value for the scaling factor is computed as: (Max representable value in the fp8 format) / (Max absolute value seen in the tensor for the previous layer).
+- For E4M3, the suggested scaling factor is `448.f/ prev_layer_tensor_amax` (rounded to the nearest lower power of two)
+- For E5M2, the suggested scaling factor is `57344.f/ prev_layer_tensor_amax` (rounded to the nearest lower power of two)
+
+The suggested value for the descale factor is the reciprocal of the scale factor (a short helper sketch illustrating this appears after the GQA/MQA notes below).
+
+Since scaling and descaling are critical for convergence with the FP8 datatype, users are required to pass scaling and descaling input tensors, as well as amax output tensors.
+
+#### Configurable Options
+
+The current FP8 support is a subset of the options supported for FP16 and BF16. We are actively working on expanding the support for FP8.
+- Attention scale (`attn_scale`): Applies a scaling factor to attention scores before the softmax, such as $\frac{1}{\sqrt{\text{d}}}$. Set to 1.0 by default.
+- Causal mask: Fills the upper triangular matrix of attention scores with negative infinity.
+
+#### Tensors
+
+The tensors in the forward operation are defined as follows:
+
+$P = QK^T$
+
+$S = \text{softmax}(P)$
+
+$O = SV$
+
+##### Input Tensors
+
+| Tensor Name | Device | Data Type | Dimensions |
+|-----------------------|------------|--------------|------------------------------|
+| Q | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{qk})$ |
+| K | GPU | E4M3 or E5M2 | $(B, H_{k}, S_{kv}, D_{qk})$ |
+| V | GPU | E4M3 or E5M2 | $(B, H_{v}, S_{kv}, D_{v})$ |
+| Descale Q | GPU | FP32 | $(1, 1, 1, 1)$ |
+| Descale K | GPU | FP32 | $(1, 1, 1, 1)$ |
+| Descale V | GPU | FP32 | $(1, 1, 1, 1)$ |
+| Descale S | GPU | FP32 | $(1, 1, 1, 1)$ |
+| Scale S | GPU | FP32 | $(1, 1, 1, 1)$ |
+
+##### Output Tensors
+
+| Tensor Name | Device | Data Type | Dimensions |
+|-----------------------|------------|--------------|------------------------------|
+| O | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ |
+| Stats (training only) | GPU | FP32 | $(B, H_{q}, S_{q}, 1)$ |
+| AMax S | GPU | FP32 | $(1, 1, 1, 1)$ |
+| AMax O | GPU | FP32 | $(1, 1, 1, 1)$ |
+
+Where,
+
+- $B$ is the batch size
+- $H_{q}$ is the number of query heads
+- $H_{k}$ is the number of key heads
+- $H_{v}$ is the number of value heads
+- $S_{q}$ is the sequence length of the query
+- $S_{kv}$ is the sequence length of the key and value
+- $D_{qk}$ is the embedding dimension per head of query and key
+- $D_{v}$ is the embedding dimension per head of value
+
+#### Group-query attention (GQA) and Multi-query attention (MQA)
+
+- As described in the paper [GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints](https://arxiv.org/abs/2305.13245):
+- When $H_{k}$ and $H_{v}$ are less than $H_{q}$ and are factors of $H_{q}$, this operation will perform group-query attention (GQA) computation.
+- When $H_{k}$ and $H_{v}$ are both set to 1, this operation performs multi-query attention (MQA) computation.
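The following helper is an illustrative sketch only (it is not part of the frontend API); the constants 448 and 57344 and the power-of-two rounding simply restate the scale/descale guidance above:

```cpp
#include <cmath>

// Illustrative helper, not part of the cudnn frontend API: derive an FP8 scale
// factor from the previous iteration's amax, as suggested in the text above.
inline float
fp8_scale_from_amax(float prev_layer_tensor_amax, bool is_e4m3) {
    float const max_representable = is_e4m3 ? 448.f : 57344.f;
    float const raw_scale         = max_representable / prev_layer_tensor_amax;
    // Round down to the nearest lower power of two, as suggested above.
    return static_cast<float>(std::exp2(std::floor(std::log2(raw_scale))));
}

// The descale factor is simply the reciprocal of the scale factor, e.g.:
//   float scale_q   = fp8_scale_from_amax(amax_q, /*is_e4m3=*/true);
//   float descale_q = 1.f / scale_q;
```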
+ +#### Limitations: +- The dimension of the embedding dimension per head $D_{qk}$ and $D_{v}$ must be a multiple of 8 with maximum value 128. +- the stride of the embedding dimension per head $D_{qk}$ and $D_{v}$ for all the tensors above must be 1. +- this operation is only supported on GPUs with NVIDIA Hopper architecture (SM90) or newer. + +#### C++ API: +```cpp +// returns [o, stats, amax_s, amax_o] +std::array, 4> +Graph::sdpa_fp8(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr descale_q, + std::shared_ptr descale_k, + std::shared_ptr descale_v, + std::shared_ptr descale_s, + std::shared_ptr scale_s, + std::shared_ptr scale_o, + SDPA_fp8_attributes attributes); ``` +The `options` parameter of type `SDPA_fp8_attributes` is used to control the attributes of the forward operation, as detailed below: + + +```cpp +SDPA_fp8_attributes& +set_is_inference(bool const value); + +SDPA_fp8_attributes& +set_attn_scale(std::shared_ptr value); + +SDPA_fp8_attributes& +set_attn_scale(float const value); + +SDPA_fp8_attributes& +set_causal_mask(bool const value); +``` + +#### Python API: +``` +Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + descale_q (cudnn_tensor): Descale factor for query. + descale_k (cudnn_tensor): Descale factor for key. + descale_v (cudnn_tensor): Descale factor for value. + descale_s (cudnn_tensor): Descale factor for S tensor. + scale_s (cudnn_tensor): Scale factor for S tensor. + scale_o (cudnn_tensor): Scale factor for output. + is_inference (bool): Whether it is an inference step or training step. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + +Returns: + o (cudnn_tensor): The output data. + stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. + amax_s (cudnn_tensor): The absolute maximum of S tensor. + amax_o (cudnn_tensor): The absolute maximum of output tensor. +``` + +### Scaled Dot Product Attention Backward FP8 + +This operation computes the gradients for scaled dot product attention in the FP8 (8-bit floating point) datatype, using the FlashAttention-2 algorithm as described in the paper [FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning](https://arxiv.org/abs/2307.08691). The user is required to pass the stats tensor from the forward operation to the backward operation as input. + +#### Configurable Options: + +All the options mentioned in the forward FP8 operation, including ragged tensors and GQA/MQA, are applicable in the backward operation as well. 
+ +#### Tensors + +The tensors in backward operation are defined as the following: + +$dV = S^TdO$ + +$dS = dOV^T$ + +$dP = \text{dSoftmax}(dS)$ + +$dQ = dPK$ + +$dK = QdP$ + +##### Input Tensors + +| Tensor Name | Device | Data Type | Dimensions | +|-----------------------|------------|--------------|------------------------------| +| Q | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{qk})$ | +| K | GPU | E4M3 or E5M2 | $(B, H_{k}, S_{kv}, D_{qk})$ | +| V | GPU | E4M3 or E5M2 | $(B, H_{v}, S_{kv}, D_{v})$ | +| O | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ | +| dO | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{v})$ | +| Descale Q | GPU | FP32 | $(1, 1, 1, 1)$ | +| Descale K | GPU | FP32 | $(1, 1, 1, 1)$ | +| Descale V | GPU | FP32 | $(1, 1, 1, 1)$ | +| Descale O | GPU | FP32 | $(1, 1, 1, 1)$ | +| Descale dO | GPU | FP32 | $(1, 1, 1, 1)$ | +| Descale S | GPU | FP32 | $(1, 1, 1, 1)$ | +| Descale dP | GPU | FP32 | $(1, 1, 1, 1)$ | +| Scale S | GPU | FP32 | $(1, 1, 1, 1)$ | +| Scale dQ | GPU | FP32 | $(1, 1, 1, 1)$ | +| Scale dK | GPU | FP32 | $(1, 1, 1, 1)$ | +| Scale dV | GPU | FP32 | $(1, 1, 1, 1)$ | +| Scale dP | GPU | FP32 | $(1, 1, 1, 1)$ | + +##### Output Tensors + +| Tensor Name | Device | Data Type | Dimensions | +|-----------------------|------------|--------------|------------------------------| +| dQ | GPU | E4M3 or E5M2 | $(B, H_{q}, S_{q}, D_{qk})$ | +| dK | GPU | E4M3 or E5M2 | $(B, H_{k}, S_{kv}, D_{qk})$ | +| dV | GPU | E4M3 or E5M2 | $(B, H_{v}, S_{kv}, D_{v})$ | +| Amax dQ | GPU | FP32 | $(1, 1, 1, 1)$ | +| Amax dK | GPU | FP32 | $(1, 1, 1, 1)$ | +| Amax dV | GPU | FP32 | $(1, 1, 1, 1)$ | +| Amax dP | GPU | FP32 | $(1, 1, 1, 1)$ | + +Where, + +- $B$ is the batch size +- $H_{q}$ is the number of query heads +- $H_{k}$ is the number of key heads +- $H_{v}$ is the number of value heads +- $S_{q}$ is the sequence length of the query +- $S_{kv}$ is the sequence length of the key and value +- $D_{qk}$ is the embedding dimension per head of query and key +- $D_{v}$ is the embedding dimension per head of value + +#### Limitations: +All the limitations mentioned in the forward operation are applicable in the backward operation as well. + +#### C++ API: +```cpp +// returns [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] +std::array, 7> +Graph::sdpa_fp8_backward(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr o, + std::shared_ptr dO, + std::shared_ptr Stats, + std::shared_ptr descale_q, + std::shared_ptr descale_k, + std::shared_ptr descale_v, + std::shared_ptr descale_o, + std::shared_ptr descale_do, + std::shared_ptr descale_s, + std::shared_ptr descale_dp, + std::shared_ptr scale_s, + std::shared_ptr scale_dq, + std::shared_ptr scale_dk, + std::shared_ptr scale_dv, + std::shared_ptr scale_dp, + SDPA_fp8_backward_attributes attributes); +``` + +The `options` parameter of type `SDPA_fp8_backward_attributes` is used to control the attributes of the forward operation, as detailed below: + + +``` +SDPA_fp8_backward_attributes& +set_attn_scale(std::shared_ptr value); + +SDPA_fp8_backward_attributes& +set_attn_scale(float const value); + +SDPA_fp8_backward_attributes& +set_causal_mask(bool const value); +``` + +#### Python API: +``` +Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + o (cudnn_tensor): The output data. + dO (cudnn_tensor): The output gradient data. + stats (cudnn_tensor): The softmax statistics in case the operation is in a training step. 
+ descale_q (cudnn_tensor): Descale factor for query. + descale_k (cudnn_tensor): Descale factor for key. + descale_v (cudnn_tensor): Descale factor for value. + descale_o (cudnn_tensor): Descale factor for output. + descale_dO (cudnn_tensor): Descale factor for output gradient. + descale_s (cudnn_tensor): Descale factor for S tensor. + descale_dP (cudnn_tensor): Descale factor for P gradient tensor. + scale_s (cudnn_tensor): Scale factor for S tensor. + scale_dQ (cudnn_tensor): Scale factor for query gradient. + scale_dK (cudnn_tensor): Scale factor for key gradient. + scale_dV (cudnn_tensor): Scale factor for value gradient. + scale_dP (cudnn_tensor): Scale factor for dP gradient. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + +Returns: + dQ (cudnn_tensor): The query gradient data. + dK (cudnn_tensor): The key gradient data. + dV (cudnn_tensor): The value gradient data. + amax_dQ (cudnn_tensor): The absolute maximum of query gradient tensor. + amax_dK (cudnn_tensor): The absolute maximum of key gradient tensor. + amax_dV (cudnn_tensor): The absolute maximum of value gradient tensor. + amax_dP (cudnn_tensor): The absolute maximum of dP tensor. +``` -#### Appendix A +### Appendix A Tensor Layouts: Q, K, V, O and corresponding gradients layout support. cuDNN API expresses the layout of tensors based on strides. @@ -310,7 +601,7 @@ Below we will go through the standard usage of the attention tensors and how the Ragged offset is insufficient to represent this. This case is NOT supported. -#### Appendix B +### Appendix B Workspace limit: Scaled Dot Product Attention Backward improves performance by using an optional dP workspace tensor. This tensor's memory consumption increases quadratically with the sequence length. The following describes the behavior of the `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT` environment variable, which allows the user to change the GPU memory limit for this workspace tensor: - `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT = unset` @@ -322,7 +613,7 @@ Scaled Dot Product Attention Backward improves performance by using an optional - `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT = n` Allows workspace optimization up to a user-defined limit of n bytes, accommodating systems with varying GPU memory capacities. -#### Appendix C +### Appendix C To dump the dropout mask generated by the Philox RNG dropout implementation for debugging purposes, users can use the `rng_dump` option. This option requires users to pass a tensor of dimensions $(B, H_{q}, S_{q}, S_{kv})$ ### Miscellaneous diff --git a/docs/operations/Resampling.md b/docs/operations/Resampling.md new file mode 100644 index 00000000..0e70209d --- /dev/null +++ b/docs/operations/Resampling.md @@ -0,0 +1,50 @@ + +## Table of Contents +1. [Resampling Forward](#Resampling_Forward) +2. [Resampling Backward](#Resampling_Backward) + +### Resampling Forward +The resample operation represents the resampling of the spatial dimensions of an image to a desired value. + +The output array contains two tensors: +1. The resampled output tensor. +2. The computed index tensor. + +NOTE: Index tensor is only outputted in training mode of max pooling. It can be fed to backward pass for faster performance. 
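Before the attribute reference that follows, here is a minimal illustrative sketch of a max-pool resample graph, loosely modeled on `samples/cpp/resample.cpp`. The enum spellings (`MAXPOOL`, `NEG_INF_PAD`), tensor shapes, and strides are assumptions for illustration, and the graph build/execute steps are omitted:

```cpp
#include <cstdint>
#include <vector>
#include <cudnn_frontend.h>
namespace fe = cudnn_frontend;

void build_maxpool_resample_graph() {
    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::HALF).set_compute_data_type(fe::DataType_t::FLOAT);

    // NHWC input image: dims are (N, C, H, W), strides describe a channels-last layout.
    auto X = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("image")
                              .set_dim({4, 16, 56, 56})
                              .set_stride({16 * 56 * 56, 1, 16 * 56, 16}));

    auto resample_attributes = fe::graph::Resample_attributes()
                                   .set_resampling_mode(fe::ResampleMode_t::MAXPOOL)
                                   .set_padding_mode(fe::PaddingMode_t::NEG_INF_PAD)
                                   .set_is_inference(false)  // training: the Index tensor is produced
                                   .set_window(std::vector<int64_t>{2, 2})
                                   .set_stride(std::vector<int64_t>{2, 2})
                                   .set_pre_padding(std::vector<int64_t>{0, 0})
                                   .set_post_padding(std::vector<int64_t>{0, 0});

    // Returns {resampled output, index tensor}; Index is meaningful only for
    // max pooling in training mode, as noted above.
    auto [Y, Index] = graph.resample(X, resample_attributes);
    Y->set_output(true);
    Index->set_output(true);  // an integer data type may also need to be set for Index
}
```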
+ +#### Resample Attributes + +The Resample_attributes class is used to configure the resampling operation. It provides the following setters: + +``` +# The resampling mode, such as average pooling, max pooling, bi-linear, or cubic. +auto set_resampling_mode(ResampleMode_t const& value) -> Resample_attributes&; + +# The padding mode, such as zero or neg infinity. +auto set_padding_mode(PaddingMode_t const& value) -> Resample_attributes&; + +# The window size to be used for the resampling operation. +auto set_window(std::vector const& value) -> Resample_attributes&; +auto set_window(std::vector const& value) -> Resample_attributes&; + +# The stride values to be used for the resampling operation. +auto set_stride(std::vector const& value) -> Resample_attributes&; +auto set_stride(std::vector const& value) -> Resample_attributes&; + +# The padding values to be applied before and after the resampling input. +auto set_pre_padding(std::vector const& value) -> Resample_attributes&; +auto set_pre_padding(std::vector const& value) -> Resample_attributes&; +auto set_post_padding(std::vector const& value) -> Resample_attributes&; +auto set_post_padding(std::vector const& value) -> Resample_attributes&; + +# A flag indicating whether the resampling is being performed during inference. +auto set_is_inference(bool const value) -> Resample_attributes&; +``` + +The cuDNN backend developer guide contains more information on the exact support surface across different versions. Please refer to its [Resampling Forward](https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html#resamplefwd) section for more details. + +The Python API for resampling forward will be supported soon. + +### Resampling Backward +To be supported soon. + diff --git a/include/cudnn_frontend.h b/include/cudnn_frontend.h index c5845459..15908f5e 100644 --- a/include/cudnn_frontend.h +++ b/include/cudnn_frontend.h @@ -124,8 +124,8 @@ #include "cudnn_frontend/utils/serialize.h" #define CUDNN_FRONTEND_MAJOR_VERSION 1 -#define CUDNN_FRONTEND_MINOR_VERSION 2 -#define CUDNN_FRONTEND_PATCH_VERSION 1 +#define CUDNN_FRONTEND_MINOR_VERSION 3 +#define CUDNN_FRONTEND_PATCH_VERSION 0 #define CUDNN_FRONTEND_VERSION \ ((CUDNN_FRONTEND_MAJOR_VERSION * 10000) + (CUDNN_FRONTEND_MINOR_VERSION * 100) + CUDNN_FRONTEND_PATCH_VERSION) diff --git a/include/cudnn_frontend/graph_helpers.h b/include/cudnn_frontend/graph_helpers.h index 34a437f9..93d45aa6 100644 --- a/include/cudnn_frontend/graph_helpers.h +++ b/include/cudnn_frontend/graph_helpers.h @@ -236,4 +236,14 @@ generate_column_major_stride_order(int64_t const num_dims) { } // namespace detail +class cudnnGraphNotSupportedException : public std::runtime_error { + public: + cudnnGraphNotSupportedException(const char* message) throw() : std::runtime_error(message) {} + + virtual const char* + what() const throw() { + return std::runtime_error::what(); + } +}; + } // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/graph_interface.h b/include/cudnn_frontend/graph_interface.h index ee2a1521..1be1b507 100644 --- a/include/cudnn_frontend/graph_interface.h +++ b/include/cudnn_frontend/graph_interface.h @@ -15,9 +15,12 @@ #include "node/layernorm.h" #include "node/instancenorm.h" #include "node/rmsnorm.h" +#include "node/resample.h" #include "node/reshape.h" // #include "node/scaled_dot_product_attention.h" #include "node/scaled_dot_product_flash_attention.h" +#include "node/sdpa_fp8.h" +#include "node/sdpa_fp8_bwd.h" #include "plans.h" #include 
"graph_helpers.h" @@ -188,6 +191,38 @@ class Graph : public INode { std::shared_ptr, std::shared_ptr, SDPA_attributes); + + std::array, 4> sdpa_fp8(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + SDPA_fp8_attributes); + + inline std::array, 7> sdpa_fp8_backward(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + SDPA_fp8_backward_attributes); + std::array, 3> sdpa_backward(std::shared_ptr, std::shared_ptr, std::shared_ptr, @@ -244,10 +279,40 @@ class Graph : public INode { return *this; } + Graph & + deselect_engines(std::vector const &engine_names) { + for (auto &plan_list : plans) { + plan_list.set_barred_names(engine_names); + } + return *this; + } + + Graph & + select_behavior_notes(std::vector const ¬es) { + for (auto &plan_list : plans) { + auto status = plan_list.filter_behavior_notes(notes, true); + if (status.is_bad()) { + getLogger() << status.get_message() << std::endl; + } + } + return *this; + } + + Graph & + select_numeric_notes(std::vector const ¬es) { + for (auto &plan_list : plans) { + auto status = plan_list.filter_numeric_notes(notes, true); + if (status.is_bad()) { + getLogger() << status.get_message() << std::endl; + } + } + return *this; + } + Graph & deselect_behavior_notes(std::vector const ¬es) { for (auto &plan_list : plans) { - auto status = plan_list.deselect_behavior_notes(notes); + auto status = plan_list.filter_behavior_notes(notes, false); if (status.is_bad()) { getLogger() << status.get_message() << std::endl; } @@ -258,7 +323,7 @@ class Graph : public INode { Graph & deselect_numeric_notes(std::vector const ¬es) { for (auto &plan_list : plans) { - auto status = plan_list.deselect_numeric_notes(notes); + auto status = plan_list.filter_numeric_notes(notes, false); if (status.is_bad()) { getLogger() << status.get_message() << std::endl; } @@ -816,6 +881,111 @@ Graph::sdpa(std::shared_ptr q, return {O, Stats}; } +inline std::array, 4> +Graph::sdpa_fp8(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr descale_q, + std::shared_ptr descale_k, + std::shared_ptr descale_v, + std::shared_ptr descale_s, + std::shared_ptr scale_s, + std::shared_ptr scale_o, + SDPA_fp8_attributes attributes) { + // Make required output tensors + auto O = attributes.outputs[SDPA_fp8_attributes::output_names::O] = output_tensor(attributes.name + "::O"); + + std::shared_ptr Stats = nullptr; + if (attributes.is_inference == false) { + Stats = attributes.outputs[SDPA_fp8_attributes::output_names::Stats] = + output_tensor(attributes.name + "::Stats"); + } + + auto Amax_S = attributes.outputs[SDPA_fp8_attributes::output_names::Amax_S] = + output_tensor(attributes.name + "::Amax_S"); + auto Amax_O = attributes.outputs[SDPA_fp8_attributes::output_names::Amax_O] = + output_tensor(attributes.name + "::Amax_O"); + + // Set inputs + attributes.inputs[SDPA_fp8_attributes::input_names::Q] = q; + attributes.inputs[SDPA_fp8_attributes::input_names::K] = k; + attributes.inputs[SDPA_fp8_attributes::input_names::V] = v; + + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_Q] = descale_q; + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_K] = 
descale_k; + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_V] = descale_v; + attributes.inputs[SDPA_fp8_attributes::input_names::Descale_S] = descale_s; + attributes.inputs[SDPA_fp8_attributes::input_names::Scale_S] = scale_s; + attributes.inputs[SDPA_fp8_attributes::input_names::Scale_O] = scale_o; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {O, Stats, Amax_S, Amax_O}; +} + +inline std::array, 7> +Graph::sdpa_fp8_backward(std::shared_ptr q, + std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr o, + std::shared_ptr dO, + std::shared_ptr Stats, + std::shared_ptr descale_q, + std::shared_ptr descale_k, + std::shared_ptr descale_v, + std::shared_ptr descale_o, + std::shared_ptr descale_do, + std::shared_ptr descale_s, + std::shared_ptr descale_dp, + std::shared_ptr scale_s, + std::shared_ptr scale_dq, + std::shared_ptr scale_dk, + std::shared_ptr scale_dv, + std::shared_ptr scale_dp, + SDPA_fp8_backward_attributes attributes) { + // Make required output tensors + auto dQ = attributes.outputs[SDPA_fp8_backward_attributes::output_names::dQ] = + output_tensor(attributes.name + "::dQ"); + auto dK = attributes.outputs[SDPA_fp8_backward_attributes::output_names::dK] = + output_tensor(attributes.name + "::dK"); + auto dV = attributes.outputs[SDPA_fp8_backward_attributes::output_names::dV] = + output_tensor(attributes.name + "::dV"); + auto Amax_dQ = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dQ] = + output_tensor(attributes.name + "::Amax_dQ"); + auto Amax_dK = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dK] = + output_tensor(attributes.name + "::Amax_dK"); + auto Amax_dV = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dV] = + output_tensor(attributes.name + "::Amax_dV"); + auto Amax_dP = attributes.outputs[SDPA_fp8_backward_attributes::output_names::Amax_dP] = + output_tensor(attributes.name + "::Amax_dP"); + + // Set inputs + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Q] = q; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::K] = k; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::V] = v; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::O] = o; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Stats] = Stats; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::dO] = dO; + + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_Q] = descale_q; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_K] = descale_k; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_V] = descale_v; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_S] = descale_s; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_O] = descale_o; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_dO] = descale_do; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Descale_dP] = descale_dp; + + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dQ] = scale_dq; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dK] = scale_dk; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dV] = scale_dv; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_S] = scale_s; + attributes.inputs[SDPA_fp8_backward_attributes::input_names::Scale_dP] = scale_dp; + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + + return {dQ, 
dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP}; +} + inline std::array, 3> Graph::sdpa_backward(std::shared_ptr q, std::shared_ptr k, diff --git a/include/cudnn_frontend/graph_properties.h b/include/cudnn_frontend/graph_properties.h index aa123ecf..279dac81 100644 --- a/include/cudnn_frontend/graph_properties.h +++ b/include/cudnn_frontend/graph_properties.h @@ -340,7 +340,7 @@ class Attributes { } // Handle shape and stride inferencing for fused scalars. - // Pick number of dimensions from anyone of non-fused-scalar input tensors + // Pick number of dimensions from anyone of non-fused-scalar input/output tensors // In case, all tensors are fused scalars, just keep them 1D. int64_t number_of_dims = 1; for (auto [name, tensor] : derived->inputs) { @@ -350,6 +350,18 @@ class Attributes { break; } } + + // If number of dims is still 1, try to see if user set output dims. + if (number_of_dims == 1) { + for (auto [name, tensor] : derived->outputs) { + (void)name; + if (tensor && (tensor->get_pass_by_value().has_value() == false)) { + number_of_dims = tensor->get_dim().size(); + break; + } + } + } + for (auto [name, tensor] : derived->inputs) { (void)name; if (tensor && tensor->get_pass_by_value().has_value()) { @@ -773,6 +785,27 @@ class Matmul_attributes : public Attributes { } }; +class Matmul_fp8_attributes : public Attributes { + friend class Attributes; + friend class MatmulFP8Node; + friend class INode; + + double padding_value = 0.0; + + public: + enum class input_names { Descale_A, Descale_B, A, B, Scale_C }; + std::map> inputs; + enum class output_names { C, Amax_C }; + std::map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Matmul_fp8_attributes, name, inputs, outputs) + + Matmul_fp8_attributes& + set_padding(double const padding_val) { + padding_value = padding_val; + return *this; + } +}; + class Pointwise_attributes : public Attributes { friend class Attributes; friend class PointwiseNode; @@ -1069,6 +1102,120 @@ class Rng_attributes : public Attributes { } }; +class Resample_attributes : public Attributes { + friend class Attributes; + friend class ResampleNode; + friend class INode; + + std::optional is_inference; + ResampleMode_t resample_mode; + PaddingMode_t padding_mode; + std::vector pre_padding; + std::vector post_padding; + std::vector stride; + std::vector window; + + public: + enum class input_names { X }; + std::unordered_map> inputs; + + enum class output_names { Y, Index }; + std::unordered_map> outputs; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Resample_attributes, + name, + inputs, + outputs, + resample_mode, + padding_mode, + pre_padding, + post_padding, + stride, + window) + + auto + set_resampling_mode(ResampleMode_t const& value) -> Resample_attributes& { + resample_mode = value; + return *this; + } + + auto + set_padding_mode(PaddingMode_t const& value) -> Resample_attributes& { + padding_mode = value; + return *this; + } + + auto + set_window(std::vector const& value) -> Resample_attributes& { + window.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + window[i].numerator = value[i]; + window[i].denominator = 1; + } + return *this; + } + + auto + set_window(std::vector const& value) -> Resample_attributes& { + window = value; + return *this; + } + + auto + set_stride(std::vector const& value) -> Resample_attributes& { + stride.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + stride[i].numerator = value[i]; + stride[i].denominator = 1; + } + return *this; + } + + auto + set_stride(std::vector const& value) -> Resample_attributes& { + 
stride = value; + return *this; + } + + auto + set_pre_padding(std::vector const& value) -> Resample_attributes& { + pre_padding.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + pre_padding[i].numerator = value[i]; + pre_padding[i].denominator = 1; + } + return *this; + } + + auto + set_pre_padding(std::vector const& value) -> Resample_attributes& { + pre_padding = value; + return *this; + } + + auto + set_post_padding(std::vector const& value) -> Resample_attributes& { + post_padding.resize(value.size()); + for (auto i = 0u; i < value.size(); i++) { + post_padding[i].numerator = value[i]; + post_padding[i].denominator = 1; + } + return *this; + } + + auto + set_post_padding(std::vector const& value) -> Resample_attributes& { + post_padding = value; + return *this; + } + + auto + set_is_inference(bool const value) -> Resample_attributes& { + is_inference = value; + return *this; + } +}; + class Reshape_attributes : public Attributes { friend class Attributes; friend class ReshapeNode; @@ -1397,6 +1544,66 @@ class SDPA_attributes : public Attributes { } }; +class SDPA_fp8_attributes : public Attributes { + friend class Attributes; + friend class SDPAFP8Node; + friend class Graph; + + std::optional is_inference; + bool causal_mask = false; + std::optional attn_scale_value; + + public: + enum class input_names { + Q, + K, + V, + Attn_scale, + Descale_Q, + Descale_K, + Descale_V, + Descale_S, + Scale_S, + Scale_O, + }; + std::map> inputs; + + enum class output_names { O, Stats, Amax_S, Amax_O }; + std::map> outputs; + + NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_fp8_attributes, + name, + inputs, + outputs, + is_inference, + causal_mask, + attn_scale_value) + + SDPA_fp8_attributes& + set_is_inference(bool const value) { + is_inference = value; + return *this; + } + + SDPA_fp8_attributes& + set_attn_scale(std::shared_ptr value) { + inputs[SDPA_fp8_attributes::input_names::Attn_scale] = value; + return *this; + } + + SDPA_fp8_attributes& + set_attn_scale(float const value) { + attn_scale_value = value; + return *this; + } + + SDPA_fp8_attributes& + set_causal_mask(bool const value) { + causal_mask = value; + return *this; + } +}; + class SDPA_backward_attributes : public Attributes { friend class Attributes; friend class SDPABackwardNode; @@ -1522,141 +1729,89 @@ class SDPA_backward_attributes : public Attributes { } }; -using Scaled_dot_product_flash_attention_attributes [[deprecated]] = SDPA_attributes; -using Scaled_dot_product_flash_attention_backward_attributes [[deprecated]] = SDPA_backward_attributes; - -class Softmax_attributes : public Attributes { - friend class Attributes; - friend class SoftmaxNode; - friend class INode; +class SDPA_fp8_backward_attributes : public Attributes { + friend class Attributes; + friend class SDPAFP8BackwardNode; + friend class Graph; - std::optional use_stats; - std::optional use_M_Zinv; + bool causal_mask = false; + std::optional attn_scale_value; public: - enum class input_names { P }; - std::map> inputs; - enum class output_names { S, Stats, M, Zinv }; - std::map> outputs; - NLOHMANN_DEFINE_TYPE_INTRUSIVE(Softmax_attributes, name, inputs, outputs, use_stats, use_M_Zinv) - - Softmax_attributes& - has_stats(bool const value) { - use_stats = value; - return *this; - } - - Softmax_attributes& - has_M_Zinv(bool const value) { - use_M_Zinv = value; - return *this; - } -}; - -class SDPA_FP8_attributes : public Attributes { - friend class Attributes; - friend class SDPA_FP8_Node; - friend class Graph; - enum class input_names { Q, K, V, - SEQ_LEN_Q, 
- SEQ_LEN_KV, + O, + dO, + Stats, Attn_scale, - Bias, - Seed, - Offset, - Dropout_mask, - Dropout_scale, - descale_Q, - descale_K, - descale_V, - scale_S, - scale_O, - ragged_offset_QKV, - ragged_offset_O + Descale_Q, + Descale_K, + Descale_V, + Descale_O, + Descale_dO, + Descale_S, + Descale_dP, + Scale_dQ, + Scale_dK, + Scale_dV, + Scale_S, + Scale_dP, }; std::map> inputs; - enum class output_names { O, Stats, M, Zinv, AMax_S, AMax_O }; + enum class output_names { dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP }; std::map> outputs; - std::optional is_inference; - bool padding_mask = false; - bool causal_mask = false; - std::optional dropout_probability; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_fp8_backward_attributes, name, inputs, outputs, causal_mask, attn_scale_value) - public: - SDPA_FP8_attributes& - set_is_inference(bool const value) { - is_inference = value; + SDPA_fp8_backward_attributes& + set_attn_scale(std::shared_ptr value) { + inputs[SDPA_fp8_backward_attributes::input_names::Attn_scale] = value; return *this; } - SDPA_FP8_attributes& - set_padding_mask(bool const value) { - padding_mask = value; + SDPA_fp8_backward_attributes& + set_attn_scale(float const value) { + attn_scale_value = value; return *this; } - SDPA_FP8_attributes& + SDPA_fp8_backward_attributes& set_causal_mask(bool const value) { causal_mask = value; return *this; } +}; - SDPA_FP8_attributes& - set_attn_scale(std::shared_ptr value) { - inputs[SDPA_FP8_attributes::input_names::Attn_scale] = value; - return *this; - } - - SDPA_FP8_attributes& - set_bias(std::shared_ptr value) { - inputs[SDPA_FP8_attributes::input_names::Bias] = value; - return *this; - } - - SDPA_FP8_attributes& - set_seq_len_q(std::shared_ptr value) { - inputs[SDPA_FP8_attributes::input_names::SEQ_LEN_Q] = value; - return *this; - } +using Scaled_dot_product_flash_attention_attributes [[deprecated]] = SDPA_attributes; +using Scaled_dot_product_flash_attention_backward_attributes [[deprecated]] = SDPA_backward_attributes; - SDPA_FP8_attributes& - set_seq_len_kv(std::shared_ptr value) { - inputs[SDPA_FP8_attributes::input_names::SEQ_LEN_KV] = value; - return *this; - } +class Softmax_attributes : public Attributes { + friend class Attributes; + friend class SoftmaxNode; + friend class INode; - SDPA_FP8_attributes& - set_ragged_offset_qkv(std::shared_ptr value) { - inputs[SDPA_FP8_attributes::input_names::ragged_offset_QKV] = value; - return *this; - } + std::optional use_stats; + std::optional use_M_Zinv; - SDPA_FP8_attributes& - set_ragged_offset_o(std::shared_ptr value) { - inputs[SDPA_FP8_attributes::input_names::ragged_offset_O] = value; - return *this; - } + public: + enum class input_names { P }; + std::map> inputs; + enum class output_names { S, Stats, M, Zinv }; + std::map> outputs; + NLOHMANN_DEFINE_TYPE_INTRUSIVE(Softmax_attributes, name, inputs, outputs, use_stats, use_M_Zinv) - SDPA_FP8_attributes& - set_dropout(float const probability, - std::shared_ptr seed, - std::shared_ptr offset) { - dropout_probability = probability; - inputs[SDPA_FP8_attributes::input_names::Seed] = seed; - inputs[SDPA_FP8_attributes::input_names::Offset] = offset; + Softmax_attributes& + has_stats(bool const value) { + use_stats = value; return *this; } - SDPA_FP8_attributes& - set_dropout(std::shared_ptr mask, std::shared_ptr scale) { - inputs[SDPA_FP8_attributes::input_names::Dropout_mask] = mask; - inputs[SDPA_FP8_attributes::input_names::Dropout_scale] = scale; + Softmax_attributes& + has_M_Zinv(bool const value) { + use_M_Zinv = value; return *this; } }; 
diff --git a/include/cudnn_frontend/node/matmul.h b/include/cudnn_frontend/node/matmul.h index 6860fedd..98ced52d 100644 --- a/include/cudnn_frontend/node/matmul.h +++ b/include/cudnn_frontend/node/matmul.h @@ -159,17 +159,17 @@ class MatmulNode : public NodeCRTP { } }; -inline void INode::matmul(std::shared_ptr a, - std::shared_ptr b, - Matmul_attributes attributes, - std::shared_ptr c) { - attributes.inputs[Matmul_attributes::input_names::A] = a; - attributes.inputs[Matmul_attributes::input_names::B] = b; - attributes.outputs[Matmul_attributes::output_names::C] = c; - sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +inline void +INode::matmul(std::shared_ptr a, + std::shared_ptr b, + Matmul_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Matmul_attributes::input_names::A] = a; + attributes.inputs[Matmul_attributes::input_names::B] = b; + attributes.outputs[Matmul_attributes::output_names::C] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); } - inline std::shared_ptr INode::matmul(std::shared_ptr a, std::shared_ptr b, @@ -182,4 +182,4 @@ INode::matmul(std::shared_ptr a, return C; } -} // namespace cudnn_frontend::graph +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/matmul_fp8.h b/include/cudnn_frontend/node/matmul_fp8.h new file mode 100644 index 00000000..a7b233be --- /dev/null +++ b/include/cudnn_frontend/node/matmul_fp8.h @@ -0,0 +1,131 @@ +#pragma once + +#include "../../cudnn_frontend_MatMulDesc.h" +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class MatmulFP8Node : public NodeCRTP { + public: + Matmul_fp8_attributes attributes; + + MatmulFP8Node(Matmul_fp8_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::MATMUL; + } + + error_t + pre_validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating matmul fp8 node " << attributes.name << "..." << std::endl; + + CUDNN_FE_VALIDATE_INPUT_TENSOR(Matmul_fp8_attributes::input_names::A); + CUDNN_FE_VALIDATE_INPUT_TENSOR(Matmul_fp8_attributes::input_names::B); + CUDNN_FE_VALIDATE_INPUT_TENSOR(Matmul_fp8_attributes::input_names::Descale_A); + CUDNN_FE_VALIDATE_INPUT_TENSOR(Matmul_fp8_attributes::input_names::Descale_B); + CUDNN_FE_VALIDATE_INPUT_TENSOR(Matmul_fp8_attributes::input_names::Scale_C); + + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(Matmul_fp8_attributes::output_names::C); + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(Matmul_fp8_attributes::output_names::Amax_C); + + CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_inputs()); + + return {error_code_t::OK, ""}; + } + + error_t + expand_and_infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferrencing properties for matmul fp8 node " << attributes.name << "..." 
+ << std::endl; + + attributes.fill_from_context(context); + + auto const& a_dim = attributes.inputs.at(Matmul_fp8_attributes::input_names::A)->get_dim(); + auto const& b_dim = attributes.inputs.at(Matmul_fp8_attributes::input_names::B)->get_dim(); + auto const& c_dim = attributes.outputs.at(Matmul_fp8_attributes::output_names::C)->get_dim(); + + std::shared_ptr last_output; + + // Matmul + auto matmul_attributes = Matmul_attributes().set_name("matmul"); + last_output = matmul(attributes.inputs.at(Matmul_fp8_attributes::input_names::A), + attributes.inputs.at(Matmul_fp8_attributes::input_names::B), + matmul_attributes); + + // Reduction if GQA for head dimension + if (a_dim.size() == 4 && b_dim.size() == 4 && c_dim.size() == 4 && a_dim[1] == b_dim[1] && + a_dim[1] != c_dim[1] && (a_dim[1] % c_dim[1] == 0)) { + auto gqa_attributes = Reduction_attributes().set_name("gqa_c").set_mode(ReductionMode_t::ADD); + last_output = reduction(last_output, gqa_attributes); + last_output->set_dim(c_dim); + } + + //// Scale Descales + auto mul_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL); + // Descale A + mul_attributes.set_name("descale_a"); + last_output = + pointwise(last_output, attributes.inputs.at(Matmul_fp8_attributes::input_names::Descale_A), mul_attributes); + + // Descale B + mul_attributes.set_name("descale_b"); + last_output = + pointwise(last_output, attributes.inputs.at(Matmul_fp8_attributes::input_names::Descale_B), mul_attributes); + + // Scale C + mul_attributes.set_name("scale_c"); + // Special non-functional-style call. Needed because output already created and provided to user. + pointwise(last_output, + attributes.inputs.at(Matmul_fp8_attributes::input_names::Scale_C), + mul_attributes, + attributes.outputs.at(Matmul_fp8_attributes::output_names::C)); + + // Amax C + auto amax_attributes = Reduction_attributes().set_name("amax_c").set_mode(ReductionMode_t::AMAX); + // Special non-functional-style call. Needed because output already created and provided to user. + reduction(last_output, amax_attributes, attributes.outputs.at(Matmul_fp8_attributes::output_names::Amax_C)); + + return {error_code_t::OK, ""}; + } + + error_t + post_validate_node() const override final { + // Validate outputs + // All properties of output tensors should have been set now. 
+ CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_outputs()); + + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "MATMUL_FP8"})"_json); + } +}; +inline void +INode::matmul_fp8(std::shared_ptr a, + std::shared_ptr b, + std::shared_ptr descale_a, + std::shared_ptr descale_b, + std::shared_ptr scale_c, + Matmul_fp8_attributes attributes, + std::shared_ptr c, + std::shared_ptr amax_c) { + attributes.inputs[Matmul_fp8_attributes::input_names::A] = a; + attributes.inputs[Matmul_fp8_attributes::input_names::B] = b; + attributes.inputs[Matmul_fp8_attributes::input_names::Descale_A] = descale_a; + attributes.inputs[Matmul_fp8_attributes::input_names::Descale_B] = descale_b; + attributes.inputs[Matmul_fp8_attributes::input_names::Scale_C] = scale_c; + attributes.outputs[Matmul_fp8_attributes::output_names::C] = c; + attributes.outputs[Matmul_fp8_attributes::output_names::Amax_C] = amax_c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +} +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/pointwise.h b/include/cudnn_frontend/node/pointwise.h index 9356e89d..186bc033 100644 --- a/include/cudnn_frontend/node/pointwise.h +++ b/include/cudnn_frontend/node/pointwise.h @@ -160,22 +160,24 @@ class PointwiseNode : public NodeCRTP { } }; -inline void INode::pointwise(std::shared_ptr a, - Pointwise_attributes attributes, - std::shared_ptr c) { - attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; - attributes.outputs[Pointwise_attributes::output_names::OUT_0] = c; - sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +inline void +INode::pointwise(std::shared_ptr a, + Pointwise_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + attributes.outputs[Pointwise_attributes::output_names::OUT_0] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); } -inline void INode::pointwise(std::shared_ptr a, - std::shared_ptr b, - Pointwise_attributes attributes, - std::shared_ptr c) { - attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; - attributes.inputs[Pointwise_attributes::input_names::IN_1] = b; - attributes.outputs[Pointwise_attributes::output_names::OUT_0] = c; - sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +inline void +INode::pointwise(std::shared_ptr a, + std::shared_ptr b, + Pointwise_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Pointwise_attributes::input_names::IN_0] = a; + attributes.inputs[Pointwise_attributes::input_names::IN_1] = b; + attributes.outputs[Pointwise_attributes::output_names::OUT_0] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); } inline std::shared_ptr @@ -215,5 +217,4 @@ INode::pointwise(std::shared_ptr a, sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); return OUT_0; } - -} // namespace cudnn_frontend::graph +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/reduction.h b/include/cudnn_frontend/node/reduction.h index f0662f78..fdb60727 100644 --- a/include/cudnn_frontend/node/reduction.h +++ b/include/cudnn_frontend/node/reduction.h @@ -125,12 +125,13 @@ class ReductionNode : public NodeCRTP { } }; -inline void INode::reduction(std::shared_ptr a, - Reduction_attributes attributes, - std::shared_ptr c) { - 
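The MatmulFP8Node above lowers one fp8 GEMM into a chain of matmul, descale_a, descale_b and scale_c pointwise nodes, taking the AMAX reduction over the descaled (real-valued) product before the final quantizing scale writes C. A minimal scalar sketch of that arithmetic, assuming plain float accumulators and illustrative names (the node itself operates on whole tensors):

    #include <algorithm>
    #include <cmath>

    // Per-element view of the MatmulFP8Node decomposition.
    // acc       : high-precision accumulator of the fp8 A*B product
    // descale_a : factor mapping fp8-encoded A values back to real values (likewise descale_b)
    // scale_c   : quantization scale applied when writing the fp8 output C
    struct Fp8MatmulScales {
        float descale_a;
        float descale_b;
        float scale_c;
    };

    inline float dequantize_requantize(float acc, Fp8MatmulScales s, float& amax_c) {
        float real = acc * s.descale_a * s.descale_b;    // descale_a, descale_b pointwise nodes
        amax_c     = std::max(amax_c, std::fabs(real));  // Amax_C is reduced over the descaled values
        return real * s.scale_c;                         // scale_c produces the fp8 output C
    }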
attributes.inputs[Reduction_attributes::input_names::X] = a; - attributes.outputs[Reduction_attributes::output_names::Y] = c; - sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +inline void +INode::reduction(std::shared_ptr a, + Reduction_attributes attributes, + std::shared_ptr c) { + attributes.inputs[Reduction_attributes::input_names::X] = a; + attributes.outputs[Reduction_attributes::output_names::Y] = c; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); } inline std::shared_ptr @@ -141,5 +142,4 @@ INode::reduction(std::shared_ptr input, Reduction_attributes sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); return Y; } - -} // namespace cudnn_frontend::graph +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/resample.h b/include/cudnn_frontend/node/resample.h new file mode 100644 index 00000000..d2d58cf9 --- /dev/null +++ b/include/cudnn_frontend/node/resample.h @@ -0,0 +1,198 @@ +#pragma once + +#include "../../cudnn_frontend_Resample.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +namespace cudnn_frontend::graph { + +class ResampleNode : public NodeCRTP { + public: + Resample_attributes attributes; + + ResampleNode(Resample_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::RESAMPLE; + } + + error_t + pre_validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating ResampleNode " << attributes.name << "..." << std::endl; + + CUDNN_FE_VALIDATE_INPUT_TENSOR(Resample_attributes::input_names::X); + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(Resample_attributes::output_names::Y); + + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.is_inference.has_value() == false, + error_code_t::ATTRIBUTE_NOT_SET, + "is_inference attribute not set"); + + if (attributes.is_inference.value() == false && attributes.resample_mode == ResampleMode_t::MAXPOOL) { + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(Resample_attributes::output_names::Index); + } + + // Make sure that the mode can be lowered to BE + cudnnResampleMode_t dummy; + RETURN_CUDNN_FRONTEND_ERROR_IF( + detail::convert_to_cudnn_type(attributes.resample_mode, dummy) != CUDNN_STATUS_SUCCESS, + error_code_t::ATTRIBUTE_NOT_SET, + "Invalid resample mode."); + + CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_inputs()); + + return {error_code_t::OK, ""}; + } + + error_t + expand_and_infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferrencing properties for resample node " << attributes.name << "..." 
+ << std::endl; + + auto y_tensor = attributes.outputs[Resample_attributes::output_names::Y]; + auto x_tensor = attributes.inputs[Resample_attributes::input_names::X]; + + attributes.fill_from_context(context); + + // If user does not set shape and layout of the output tensor, + // Get it from node attributes + if (y_tensor->get_dim().empty()) { + auto const x_dim = x_tensor->get_dim(); + auto y_dim = y_tensor->get_dim(); + y_dim = x_dim; + + // 2 cause first two dimensions are batch and channels + for (auto dim = 2u; dim < x_dim.size(); ++dim) { + auto spatial_dim = dim - 2u; + y_dim[dim] = + 1 + (x_dim[dim] + attributes.pre_padding[spatial_dim].numerator + + attributes.post_padding[spatial_dim].numerator - attributes.window[spatial_dim].numerator) / + attributes.stride[spatial_dim].numerator; + } + + y_tensor->set_dim(y_dim); + } + + // If layout is not set, generate the strides from layout + if (y_tensor->get_stride().empty()) { + auto const& y_dim = y_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(y_dim.size()); + y_tensor->set_stride(detail::generate_stride(y_dim, stride_order)); + } + + if (attributes.outputs[Resample_attributes::output_names::Index]) { + auto index_tensor = attributes.outputs[Resample_attributes::output_names::Index]; + index_tensor->set_dim(y_tensor->get_dim()); + + // If layout is not set, generate the strides from layout + if (index_tensor->get_stride().empty()) { + auto const& index_dim = index_tensor->get_dim(); + // Default to NHWC + auto const& stride_order = detail::generate_NHWC_stride_order(index_dim.size()); + index_tensor->set_stride(detail::generate_stride(index_dim, stride_order)); + } + } + + return {error_code_t::OK, ""}; + } + + error_t + post_validate_node() const override final { + // Validate outputs + // All properties of output tensors should have been set now. + CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_outputs()); + + return {error_code_t::OK, ""}; + } + + error_t + create_cudnn_operations( + std::unordered_set& uids_involved_in_operations, + std::vector>& operations, + std::unordered_map>& tensors) const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Building ResampleNode operations " << attributes.name << "..." 
<< std::endl; + + auto number_of_spatial_dim = static_cast(attributes.window.size()); + + // Define the resample descriptor + auto resample_descriptor = cudnn_frontend::ResampleDescBuilder_v8() + .setComputeType(attributes.compute_data_type) + .setNanPropagation(CUDNN_PROPAGATE_NAN) + .setResampleMode(attributes.resample_mode) + .setPaddingMode(attributes.padding_mode) + .setSpatialDim(number_of_spatial_dim, attributes.window.data()) + .setSpatialStride(number_of_spatial_dim, attributes.stride.data()) + .setPrePadding(number_of_spatial_dim, attributes.pre_padding.data()) + .setPostPadding(number_of_spatial_dim, attributes.post_padding.data()) + .build(); + + auto&& resample_op_builder = + cudnn_frontend::OperationBuilder(DescriptorType_t::OPERATION_RESAMPLE_FWD_DESCRIPTOR); + + CUDNN_FE_VALIDATE_AND_ASSIGN_INPUT_TENSOR(X, Resample_attributes::input_names::X); + resample_op_builder.setxDesc(*(tensors.at(X->second->get_uid()))); + + CUDNN_FE_VALIDATE_AND_ASSIGN_OUTPUT_TENSOR(Y, Resample_attributes::output_names::Y); + resample_op_builder.setyDesc(*(tensors.at(Y->second->get_uid()))); + + resample_op_builder.setResampleDesc(resample_descriptor); + + auto index = attributes.outputs.find(Resample_attributes::output_names::Index); + if ((index != attributes.outputs.end()) && (index->second != nullptr)) { + resample_op_builder.setidxDesc(*tensors.at(index->second->get_uid())); + } + +#ifdef NV_CUDNN_DISABLE_EXCEPTION + // disable exception macro is defined. Calling build will not throw. + // Check status of desc and return error. + auto operation = resample_op_builder.build(); + RETURN_CUDNN_FRONTEND_ERROR_IF(operation.get_status() != CUDNN_STATUS_SUCCESS, + error_code_t::CUDNN_BACKEND_API_FAILED, + operation.get_error()); + operations.push_back(std::make_shared(std::move(operation))); +#else + // build() can throw + // wrap in try catch + try { + auto operation = resample_op_builder.build(); + operations.push_back(std::make_shared(std::move(operation))); + } catch (cudnn_frontend::cudnnException& e) { + RETURN_CUDNN_FRONTEND_ERROR_IF( + e.getCudnnStatus() != CUDNN_STATUS_SUCCESS, error_code_t::CUDNN_BACKEND_API_FAILED, e.what()); + } +#endif + + auto const& non_virtual_uids = attributes.get_non_virtual_uids(); + uids_involved_in_operations.insert(non_virtual_uids.begin(), non_virtual_uids.end()); + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"( {"tag": "RESAMPLE"})"_json); + } +}; + +inline std::array, 2> +INode::resample(std::shared_ptr input, Resample_attributes attributes) { + attributes.inputs[Resample_attributes::input_names::X] = input; + auto Y = attributes.outputs[Resample_attributes::output_names::Y] = output_tensor(attributes.name + "::Y"); + std::shared_ptr Index = nullptr; + if (attributes.is_inference.has_value() && attributes.is_inference.value() == false && + attributes.resample_mode == ResampleMode_t::MAXPOOL) { + Index = attributes.outputs[Resample_attributes::output_names::Index] = + output_tensor(attributes.name + "::Index"); + } + + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); + return {Y, Index}; +} + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/reshape.h b/include/cudnn_frontend/node/reshape.h index f16b35cd..86567b75 100644 --- a/include/cudnn_frontend/node/reshape.h +++ b/include/cudnn_frontend/node/reshape.h @@ -139,4 +139,4 @@ INode::reshape(std::shared_ptr input, Reshape_attributes attr return Y; } 
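When the caller leaves the resample output shape unset, the ResampleNode above infers each spatial extent from the window, stride and padding attributes (the first two dimensions are batch and channels and pass through). A small sketch of that computation, using plain integers where the real attributes carry numerator fields:

    #include <cstdint>
    #include <vector>

    // out = 1 + (in + pre_pad + post_pad - window) / stride, applied per spatial dimension.
    std::vector<int64_t> infer_resample_output_dims(std::vector<int64_t> const& x_dim,
                                                    std::vector<int64_t> const& window,
                                                    std::vector<int64_t> const& stride,
                                                    std::vector<int64_t> const& pre_pad,
                                                    std::vector<int64_t> const& post_pad) {
        std::vector<int64_t> y_dim = x_dim;  // batch and channel dimensions are unchanged
        for (size_t d = 2; d < x_dim.size(); ++d) {
            size_t s = d - 2;  // spatial index
            y_dim[d] = 1 + (x_dim[d] + pre_pad[s] + post_pad[s] - window[s]) / stride[s];
        }
        return y_dim;
    }
    // Example: a 1x16x56x56 input with a 3x3 window, stride 2 and symmetric padding 1 maps to 1x16x28x28.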
-} // namespace cudnn_frontend::graph +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/rng.h b/include/cudnn_frontend/node/rng.h index 46bedca8..71c2b6af 100644 --- a/include/cudnn_frontend/node/rng.h +++ b/include/cudnn_frontend/node/rng.h @@ -144,14 +144,15 @@ class RngNode : public NodeCRTP { } }; -inline void INode::rng(std::shared_ptr seed, - std::shared_ptr offset, - Rng_attributes attributes, - std::shared_ptr y) { - attributes.inputs[Rng_attributes::input_names::Seed] = seed; - attributes.inputs[Rng_attributes::input_names::Offset] = offset; - attributes.outputs[Rng_attributes::output_names::Y] = y; - sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +inline void +INode::rng(std::shared_ptr seed, + std::shared_ptr offset, + Rng_attributes attributes, + std::shared_ptr y) { + attributes.inputs[Rng_attributes::input_names::Seed] = seed; + attributes.inputs[Rng_attributes::input_names::Offset] = offset; + attributes.outputs[Rng_attributes::output_names::Y] = y; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); } inline std::shared_ptr @@ -165,5 +166,4 @@ INode::rng(std::shared_ptr seed, sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); return Y; } - -} // namespace cudnn_frontend::graph +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h index a8ef75ef..c299de34 100644 --- a/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h +++ b/include/cudnn_frontend/node/scaled_dot_product_flash_attention.h @@ -19,6 +19,7 @@ class SDPANode : public NodeCRTP { std::shared_ptr rng_output; std::shared_ptr alibi_slopes; + int64_t alibi_slopes_size = 0; public: SDPA_attributes attributes; @@ -276,6 +277,7 @@ class SDPANode : public NodeCRTP { .set_stride({h, 1, 1, 1}) // Hard code data type float as FE itself will compute and place in variant pack later .set_data_type(DataType_t::FLOAT); + alibi_slopes_size = h * sizeof(float); auto mul_attributes = Pointwise_attributes().set_name("mul").set_mode(PointwiseMode_t::MUL); auto const& alibi_mask = pointwise(sub_output, alibi_slopes, mul_attributes); @@ -502,10 +504,12 @@ class SDPANode : public NodeCRTP { virtual int64_t get_fe_workspace_size_node() const override final { - auto const& q = attributes.inputs.find(input_names::Q); - int64_t const h = q->second->get_dim()[1]; - int64_t alibi_slopes_size = h * sizeof(float); - return (alibi_slopes_size + 15) & ~15; + int64_t size = 0; + + // align alibi slopes memory to 16 bytes + size += ((alibi_slopes_size + 15) / 16 * 16); + + return size; } virtual error_t @@ -517,6 +521,8 @@ class SDPANode : public NodeCRTP { int64_t const h_q = Q->second->get_dim()[1]; auto alibi_slopes_vec = detail::get_abili_slope(h_q); workspace_modifications.emplace(alibi_slopes->get_uid(), std::make_tuple(0, offset, alibi_slopes_vec)); + int64_t alibi_slopes_size_padded = ((alibi_slopes_size + 15) / 16 * 16); + offset = offset + alibi_slopes_size_padded; } return {error_code_t::OK, ""}; } @@ -1253,10 +1259,13 @@ class SDPABackwardNode : public NodeCRTP { virtual int64_t get_fe_workspace_size_node() const override final { - // set in infer_properties_node() - // align alibi slopes memory to 16 bytes - int64_t alibi_slopes_size_padded = (alibi_slopes_size + 15) & ~15; - return alibi_slopes_size_padded + dQ_accum_size + 
softmax_sum_size; + int64_t size = 0; + + size += ((alibi_slopes_size + 15) / 16 * 16); // align alibi slopes memory to 16 bytes + size += dQ_accum_size; + size += softmax_sum_size; + + return size; } virtual error_t @@ -1268,7 +1277,7 @@ class SDPABackwardNode : public NodeCRTP { int64_t const h_q = Q->second->get_dim()[1]; auto alibi_slopes_vec = detail::get_abili_slope(h_q); workspace_modifications.emplace(alibi_slopes->get_uid(), std::make_tuple(0, offset, alibi_slopes_vec)); - int64_t alibi_slopes_size_padded = (alibi_slopes_size + 15) & ~15; + int64_t alibi_slopes_size_padded = ((alibi_slopes_size + 15) / 16 * 16); offset = offset + alibi_slopes_size_padded; } @@ -1294,4 +1303,4 @@ class SDPABackwardNode : public NodeCRTP { } }; -} // namespace cudnn_frontend::graph \ No newline at end of file +} // namespace cudnn_frontend::graph diff --git a/include/cudnn_frontend/node/sdpa_fp8.h b/include/cudnn_frontend/node/sdpa_fp8.h new file mode 100644 index 00000000..6e0ff06a --- /dev/null +++ b/include/cudnn_frontend/node/sdpa_fp8.h @@ -0,0 +1,263 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +#include "matmul_fp8.h" +#include "pointwise.h" +#include "reduction.h" +#include "softmax.h" + +namespace cudnn_frontend::graph { + +class SDPAFP8Node : public NodeCRTP { + using input_names = SDPA_fp8_attributes::input_names; + using output_names = SDPA_fp8_attributes::output_names; + + public: + SDPA_fp8_attributes attributes; + + SDPAFP8Node(SDPA_fp8_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::COMPOSITE; + } + + error_t + pre_validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating SDPAFP8Node " << attributes.name << "..." << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF(cudnn_frontend::get_backend_version() < 90100, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported starting cudnn 9.1.0. Please " + "consider upgrading your current version."); + + cudaDeviceProp prop; + int device; + CHECK_CUDA_ERROR(cuda_get_device(&device)); + CHECK_CUDA_ERROR(cuda_get_device_properties(&prop, device)); + RETURN_CUDNN_FRONTEND_ERROR_IF( + prop.major < 9, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on Hopper architecture and newer. 
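Both SDPA workspace computations above now round the alibi-slopes buffer up to a 16-byte boundary with (size + 15) / 16 * 16, and the forward node only contributes that buffer when the alibi mask was actually built, which is the workspace bug fix called out in the release notes. The rounding is ordinary align-up arithmetic:

    #include <cstdint>

    // Round size up to the next multiple of alignment (16 bytes for the alibi slopes buffer).
    constexpr int64_t align_up(int64_t size, int64_t alignment) {
        return (size + alignment - 1) / alignment * alignment;
    }

    static_assert(align_up(0, 16) == 0, "no alibi mask, no workspace contribution");
    static_assert(align_up(32, 16) == 32, "8 float slopes (32 bytes) are already aligned");
    static_assert(align_up(40, 16) == 48, "10 float slopes round up to the next 16-byte boundary");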
Please " + "consider using a newer architecture."); + + // check that Q, K, V, O tensors has been assigned + // check that dim and strides has been assigned and last stride is 1 +#define CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(port, port_map) \ + { \ + std::shared_ptr tensor_ptr = port_map.at(port); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_dim().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The dim for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_stride().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The stride for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + tensor_ptr->get_stride()[3] != 1, \ + error_code_t::GRAPH_NOT_SUPPORTED, \ + "The stride for the last dimension corresponding to the embedding size per head should be 1 for " + \ + std::string(#port)); \ + } + + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::Q); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Q, attributes.inputs); + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::K); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::K, attributes.inputs); + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::V); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::V, attributes.inputs); + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::O); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::O, attributes.outputs); + + // validate options for is_inference and stats tensor + RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.is_inference.has_value() == false, + error_code_t::ATTRIBUTE_NOT_SET, + "is_inference attribute not set"); + + if (attributes.is_inference.value() == false) { + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::Stats); + } + +#undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE + + CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_inputs()); + return {error_code_t::OK, ""}; + } + + error_t + expand_and_infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferrencing properties for Scaled_dot_product_flash_attention node " + << attributes.name << "..." 
<< std::endl; + + // DO NOT REMOVE + // input data type is needed for: + // - aType of bmm2 + attributes.fill_from_context(context); + + // Gather dim to fill properties of virtual tensors + // auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); + // auto b = q_dim[0]; + // auto h = q_dim[1]; + // auto s_q = q_dim[2]; + // auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); + // auto s_kv = k_dim[2]; + + // cuDNN frontend API attention requires Q, K, V where + // Q = {b, h_q, s_q, d_qk} + // K = {b, h_k, s_kv, d_qk} + // V = {b, h_v, s_kv, d_v} + // but cuDNN backend API attention requires Q, KT, V + // Q = {b, h_q, s_q, d_qk} + // KT = {b, h_k, d_qk, s_kv} + // V = {b, h_v, s_kv, d_v} + // So the code below maps the K->KT + std::vector temp_vec; + + temp_vec = attributes.inputs[input_names::K]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::K]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_stride(temp_vec); + + // This tensor tracks the main chain of data flow + std::shared_ptr last_output; + auto mul_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL); + + //// Q * K + auto bmm1_attributes = Matmul_attributes().set_name("bmm1").set_padding(0.0); + last_output = matmul(attributes.inputs[input_names::Q], attributes.inputs[input_names::K], bmm1_attributes); + + //// Optional Attn scale + // In case user provided a scalar value, do a fused scalar. + if (attributes.attn_scale_value.has_value()) { + attributes.inputs[input_names::Attn_scale] = + std::make_shared(attributes.attn_scale_value.value()); + } + + // If attn scale present, add a pointwise mul node + if (attributes.inputs[input_names::Attn_scale]) { + mul_attributes.set_name("attn_scale"); + auto const& attn_scale_output = + pointwise(last_output, attributes.inputs[input_names::Attn_scale], mul_attributes); + last_output = attn_scale_output; + } + + //// Descales + // Descale Q + mul_attributes.set_name("descale_q"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_Q), mul_attributes); + + // Descale K + mul_attributes.set_name("descale_k"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_K), mul_attributes); + + //// Optional causal masking + if (attributes.causal_mask) { + auto row_index_attributes = + Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); + auto const& row_index_output = pointwise(last_output, row_index_attributes); + + auto col_index_attributes = + Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); + auto const& col_index_output = pointwise(last_output, col_index_attributes); + + auto greater_than_attributes = Pointwise_attributes() + .set_name("row_greater_than_col") + .set_mode(PointwiseMode_t::CMP_GE) + .set_compute_data_type(DataType_t::BOOLEAN); + auto const& row_greater_than_col_output = + pointwise(row_index_output, col_index_output, greater_than_attributes); + row_greater_than_col_output->set_data_type(DataType_t::BOOLEAN); + + // Lower attributes to binary select attributes + auto negative_inf_causal = std::make_shared(std::numeric_limits::lowest()); + + auto binary_select_attributes = + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); + auto const& causal_mask_output = + pointwise(last_output, negative_inf_causal, 
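The K to KT mapping above (and the additional V to VT mapping in the backward node) is a pure metadata transpose: swapping the last two entries of both the dim and the stride vector re-interprets the same buffer without moving any data. A self-contained sketch of why the swapped view addresses the same elements:

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    int main() {
        // K viewed as {b, h, s_kv, d} with contiguous strides.
        std::vector<int64_t> dim    = {1, 2, 4, 8};
        std::vector<int64_t> stride = {2 * 4 * 8, 4 * 8, 8, 1};

        auto offset = [&](int64_t b, int64_t h, int64_t s, int64_t e) {
            return b * stride[0] + h * stride[1] + s * stride[2] + e * stride[3];
        };
        int64_t before = offset(0, 1, 3, 5);  // K[0][1][3][5]

        // The node's K -> KT mapping: swap the last two dims and strides.
        std::swap(dim[2], dim[3]);
        std::swap(stride[2], stride[3]);
        assert(dim[2] == 8 && dim[3] == 4);

        int64_t after = offset(0, 1, 5, 3);   // KT[0][1][5][3] addresses the same element
        assert(before == after);
        return 0;
    }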
row_greater_than_col_output, binary_select_attributes); + last_output = causal_mask_output; + } + + //// Softmax + // softmax output, S, is always virtual. + auto softmax_output = std::make_shared(); + softmax_output->set_is_virtual(true); + + // Create virtual stats if inference step otherwise output.Stats should be provided by user. + auto softmax_stats = attributes.outputs[output_names::Stats]; + if (attributes.is_inference.value() == true) { + softmax_stats = std::make_shared(); + softmax_stats->set_is_virtual(true); + } + + auto softmax_attributes = + Softmax_attributes().set_name("softmax").has_stats(true).has_M_Zinv(false); // As this is flash attention + // Special non-functional-style call. Needed because output already created and provided to user. + softmax(last_output, softmax_attributes, softmax_output, softmax_stats); + last_output = softmax_output; + + // Amax S + auto amax_attributes = Reduction_attributes().set_name("amax_s").set_mode(ReductionMode_t::AMAX); + // Special non-functional-style call. Needed because output already created and provided to user. + reduction(last_output, amax_attributes, attributes.outputs.at(output_names::Amax_S)); + + // Scale S + mul_attributes.set_name("scale_s"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Scale_S), mul_attributes); + last_output->set_data_type(attributes.inputs.at(input_names::Q)->get_data_type()); + + //// S * V + auto bmm2_attributes = Matmul_fp8_attributes().set_name("bmm2"); + // Special non-functional-style call. Needed because output already created and provided to user. + matmul_fp8(last_output, + attributes.inputs.at(input_names::V), + attributes.inputs.at(input_names::Descale_S), + attributes.inputs.at(input_names::Descale_V), + attributes.inputs.at(input_names::Scale_O), + bmm2_attributes, + attributes.outputs.at(output_names::O), + attributes.outputs.at(output_names::Amax_O)); + + return {error_code_t::OK, ""}; + } + + error_t + post_validate_node() const override final { +#define CUDNN_FE_VALIDATE_STRIDE(port, port_map) \ + { \ + auto const& t = port_map.find(port); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + t->second->get_stride().back() != 1, \ + error_code_t::GRAPH_NOT_SUPPORTED, \ + "The stride for the last dimension corresponding to the embedding size per head should be 1 for " + \ + std::string(#port)); \ + } + + CUDNN_FE_VALIDATE_STRIDE(output_names::O, attributes.outputs); + +#undef CUDNN_FE_VALIDATE_STRIDE + + // Validate outputs + // All properties of output tensors should have been set now. 
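The forward node above emits Amax_S and Amax_O alongside the scaled outputs. A typical reason for doing so, not something this node mandates, is a delayed-scaling fp8 recipe in which the caller folds the observed maxima into the next iteration's quantization scales. A hedged sketch of that policy (the fp8 range constant and margin are illustrative):

    #include <algorithm>

    // How a caller might turn Amax_S / Amax_O into the next iteration's scales.
    constexpr float kFp8E4M3Max = 448.0f;  // largest finite value representable in fp8 e4m3

    inline float next_scale_from_amax(float amax, float margin = 1.0f) {
        // Map the observed absolute maximum onto the representable fp8 range.
        float a = std::max(amax, 1e-12f);  // guard against dividing by zero
        return kFp8E4M3Max / (a * margin);
    }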
+ CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_outputs()); + + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "SDPA_FP8_FWD"})"_json); + } +}; + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/sdpa_fp8_bwd.h b/include/cudnn_frontend/node/sdpa_fp8_bwd.h new file mode 100644 index 00000000..b7c9c8a7 --- /dev/null +++ b/include/cudnn_frontend/node/sdpa_fp8_bwd.h @@ -0,0 +1,345 @@ +#pragma once + +#include "../../cudnn_frontend_Heuristics.h" +#include "../../cudnn_frontend_Logging.h" + +#include "../graph_helpers.h" +#include "../node_interface.h" + +#include "matmul_fp8.h" +#include "pointwise.h" +#include "reduction.h" +#include "softmax.h" + +namespace cudnn_frontend::graph { + +class SDPAFP8BackwardNode : public NodeCRTP { + using input_names = SDPA_fp8_backward_attributes::input_names; + using output_names = SDPA_fp8_backward_attributes::output_names; + + public: + SDPA_fp8_backward_attributes attributes; + + SDPAFP8BackwardNode(SDPA_fp8_backward_attributes&& attributes_, detail::Context const& context) + : NodeCRTP(context), attributes(std::move(attributes_)) {} + + Type + getType() override final { + return Type::COMPOSITE; + } + + error_t + pre_validate_node() const override final { + getLogger() << "[cudnn_frontend] INFO: " + << "Validating SDPAFP8BackwardNode " << attributes.name << "..." << std::endl; + + RETURN_CUDNN_FRONTEND_ERROR_IF(cudnn_frontend::get_backend_version() < 90100, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 backward operation is only supported starting cudnn 9.1.0. Please " + "consider upgrading your current version."); + + cudaDeviceProp prop; + int device; + CHECK_CUDA_ERROR(cuda_get_device(&device)); + CHECK_CUDA_ERROR(cuda_get_device_properties(&prop, device)); + RETURN_CUDNN_FRONTEND_ERROR_IF( + prop.major < 9, + error_code_t::GRAPH_NOT_SUPPORTED, + "sdpa fp8 forward operation is only supported on Hopper architecture and newer. 
Please " + "consider using a newer architecture."); + + // check that Q, K, V, O, stats, dO, dQ, dK, dV tensors has been assigned + // check that dim and strides has been assigned and last stride is 1 +#define CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(port, port_map) \ + { \ + std::shared_ptr tensor_ptr = port_map.at(port); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_dim().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The dim for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF(tensor_ptr->get_stride().size() != 4, \ + error_code_t::ATTRIBUTE_NOT_SET, \ + "The stride for " + std::string(#port) + " is invalid"); \ + RETURN_CUDNN_FRONTEND_ERROR_IF( \ + tensor_ptr->get_stride()[3] != 1, \ + error_code_t::GRAPH_NOT_SUPPORTED, \ + "The stride for the last dimension corresponding to the embedding size per head should be 1 for " + \ + std::string(#port)); \ + } + + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::Q); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Q, attributes.inputs); + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::K); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::K, attributes.inputs); + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::V); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::V, attributes.inputs); + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::O); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::O, attributes.inputs); + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::Stats); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::Stats, attributes.inputs); + CUDNN_FE_VALIDATE_INPUT_TENSOR(input_names::dO); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(input_names::dO, attributes.inputs); + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::dQ); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dQ, attributes.outputs); + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::dK); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dK, attributes.outputs); + CUDNN_FE_VALIDATE_OUTPUT_TENSOR(output_names::dV); + CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE(output_names::dV, attributes.outputs); + +#undef CUDNN_FE_SDPA_VALIDATE_DIM_STRIDE + + // validate options for attn_scale + auto const& attn_scale = attributes.inputs.find(input_names::Attn_scale); + bool const has_attn_scale = (attn_scale != attributes.inputs.end()) && (attn_scale->second != nullptr); + RETURN_CUDNN_FRONTEND_ERROR_IF(has_attn_scale && attributes.attn_scale_value.has_value(), + error_code_t::ATTRIBUTE_NOT_SET, + "attn_scale with tensor and value cannot be set at the same time."); + + CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_inputs()); + return {error_code_t::OK, ""}; + } + + error_t + expand_and_infer_properties_node() override final { + getLogger() << "[cudnn_frontend] INFO: Inferrencing properties for Scaled_dot_product_flash_attention node " + << attributes.name << "..." 
<< std::endl; + + attributes.fill_from_context(context); + + // Gather dim to fill properties of virtual tensors + auto const& q_dim = attributes.inputs[input_names::Q]->get_dim(); + auto b = q_dim[0]; + auto h_q = q_dim[1]; + auto s_q = q_dim[2]; + // auto d_qk = q_dim[3]; + auto const& k_dim = attributes.inputs[input_names::K]->get_dim(); + // auto h_k = k_dim[1]; + auto s_kv = k_dim[2]; + // auto const& v_dim = attributes.inputs[input_names::V]->get_dim(); + // auto h_v = v_dim[1]; + // auto d_v = v_dim[3]; + + // cuDNN frontend API attention requires Q, K, V where + // Q = {b, h_q, s_q, d_qk} + // K = {b, h_k, s_kv, d_qk} + // V = {b, h_v, s_kv, d_v} + // but cuDNN backend API attention requires Q, KT, VT + // Q = {b, h_q, s_q, d_qk} + // KT = {b, h_k, d_qk, s_kv} + // VT = {b, h_v, d_v, s_kv} + // So the code below maps the K->KT and V->VT + std::vector temp_vec; + + temp_vec = attributes.inputs[input_names::K]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::K]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::K]->set_stride(temp_vec); + + temp_vec = attributes.inputs[input_names::V]->get_dim(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::V]->set_dim(temp_vec); + + temp_vec = attributes.inputs[input_names::V]->get_stride(); + std::swap(temp_vec[2], temp_vec[3]); + attributes.inputs[input_names::V]->set_stride(temp_vec); + + auto mul_attributes = Pointwise_attributes().set_mode(PointwiseMode_t::MUL); + + //// dO * O + mul_attributes.set_name("mul_dO_O"); + auto last_output = + pointwise(attributes.inputs[input_names::dO], attributes.inputs[input_names::O], mul_attributes); + + // reduce(dO) + last_output = + reduction(last_output, Reduction_attributes().set_name("reduce_dO").set_mode(ReductionMode_t::ADD)); + last_output->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1}); + + // Descale dO + mul_attributes.set_name("descale_dO"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_dO), mul_attributes); + + // Descale O + mul_attributes.set_name("descale_O"); + auto softmax_sum = pointwise(last_output, attributes.inputs.at(input_names::Descale_O), mul_attributes); + + //// Q * K + auto bmm_Q_K_attributes = Matmul_attributes().set_name("bmm_Q_K"); + auto last_dV = matmul(attributes.inputs[input_names::Q], attributes.inputs[input_names::K], bmm_Q_K_attributes); + + //// Optional Attn scale + // In case user provided a scalar value, do a fused scalar. 
+ if (attributes.attn_scale_value.has_value()) { + attributes.inputs[input_names::Attn_scale] = + std::make_shared(attributes.attn_scale_value.value()); + } + + // If attn scale present, add a pointwise mul node + if (attributes.inputs[input_names::Attn_scale]) { + mul_attributes.set_name("attn_scale"); + last_dV = pointwise(last_dV, attributes.inputs[input_names::Attn_scale], mul_attributes); + } + + //// Descales + // Descale Q + mul_attributes.set_name("descale_q"); + last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Descale_Q), mul_attributes); + + // Descale K + mul_attributes.set_name("descale_k"); + last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Descale_K), mul_attributes); + + //// Optional causal masking + if (attributes.causal_mask) { + auto row_index_attributes = + Pointwise_attributes().set_name("gen_row_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(2); + auto const& row_index_output = pointwise(last_dV, row_index_attributes); + + auto col_index_attributes = + Pointwise_attributes().set_name("gen_col_index").set_mode(PointwiseMode_t::GEN_INDEX).set_axis(3); + auto const& col_index_output = pointwise(last_dV, col_index_attributes); + + auto greater_than_attributes = Pointwise_attributes() + .set_name("row_greater_than_col") + .set_mode(PointwiseMode_t::CMP_GE) + .set_compute_data_type(DataType_t::BOOLEAN); + auto const& row_greater_than_col_output = + pointwise(row_index_output, col_index_output, greater_than_attributes); + row_greater_than_col_output->set_data_type(DataType_t::BOOLEAN); + + // Lower attributes to binary select attributes + auto negative_inf_causal = std::make_shared(std::numeric_limits::lowest()); + + auto binary_select_attributes = + Pointwise_attributes().set_name("binary_select").set_mode(PointwiseMode_t::BINARY_SELECT); + last_dV = pointwise(last_dV, negative_inf_causal, row_greater_than_col_output, binary_select_attributes); + } + + //// Apply Softmax + // last_dV = last_dV - stats + last_dV = pointwise(last_dV, + attributes.inputs[input_names::Stats], + Pointwise_attributes().set_name("sub_dV_Stats").set_mode(PointwiseMode_t::SUB)); + + // last_dV = exp(last_dV) + last_dV = pointwise(last_dV, Pointwise_attributes().set_name("exp_dV").set_mode(PointwiseMode_t::EXP)); + auto exp_S = last_dV; + + // Scale S + mul_attributes.set_name("scale_S"); + last_dV = pointwise(last_dV, attributes.inputs.at(input_names::Scale_S), mul_attributes); + last_dV->set_data_type(attributes.inputs.at(input_names::Q)->get_data_type()); + + // Reshape S + last_dV = reshape(last_dV, Reshape_attributes().set_name("S_transpose")); + last_dV->set_name("S_T").set_dim({b, h_q, s_kv, s_q}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 1, s_kv}); + last_dV->set_data_type(attributes.inputs[input_names::Q]->get_data_type()); + + //// S_T * dO + // Special non-functional-style call. Needed because output already created and provided to user. 
+ matmul_fp8(last_dV, + attributes.inputs[input_names::dO], + attributes.inputs[input_names::Descale_S], + attributes.inputs[input_names::Descale_dO], + attributes.inputs[input_names::Scale_dV], + Matmul_fp8_attributes().set_name("bmm_S_T_dO"), + attributes.outputs[output_names::dV], + attributes.outputs[output_names::Amax_dV]); + + //// dO * V_T + auto bmm_dO_V_T_attributes = Matmul_attributes().set_name("bmm_dO_V_T"); + last_output = + matmul(attributes.inputs[input_names::dO], attributes.inputs[input_names::V], bmm_dO_V_T_attributes); + + //// Descales + // Descale dO + mul_attributes.set_name("descale_dO"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_dO), mul_attributes); + + // Descale V + mul_attributes.set_name("descale_V"); + last_output = pointwise(last_output, attributes.inputs.at(input_names::Descale_V), mul_attributes); + + // dP = last_output - softmax_sum + auto dP = pointwise(last_output, + softmax_sum, + Pointwise_attributes().set_name("sub_dP_softmax_sum").set_mode(PointwiseMode_t::SUB)); + + // dP = dP * exp_S + mul_attributes.set_name("mul_dP_exp_S"); + dP = pointwise(dP, exp_S, mul_attributes); + + // (optional) dP = dP * attn_scale + if (attributes.inputs[input_names::Attn_scale]) { + mul_attributes.set_name("mul_dS_attn_scale"); + dP = pointwise(dP, attributes.inputs[input_names::Attn_scale], mul_attributes); + } + + // Amax dP + auto amax_attributes = Reduction_attributes().set_name("amax_dP").set_mode(ReductionMode_t::AMAX); + // Special non-functional-style call. Needed because output already created and provided to user. + reduction(dP, amax_attributes, attributes.outputs.at(output_names::Amax_dP)); + + // Scale dP + mul_attributes.set_name("scale_dP"); + dP = pointwise(dP, attributes.inputs.at(input_names::Scale_dP), mul_attributes); + dP->set_data_type(attributes.inputs.at(input_names::dO)->get_data_type()); + + //// dP * K + auto const& kt_dim = attributes.inputs[input_names::K]->get_dim(); + auto const& kt_stride = attributes.inputs[input_names::K]->get_stride(); + + auto K = reshape(attributes.inputs[input_names::K], Reshape_attributes().set_name("reshape_K")); + K->set_dim({kt_dim[0], kt_dim[1], kt_dim[3], kt_dim[2]}) + .set_stride({kt_stride[0], kt_stride[1], kt_stride[3], kt_stride[2]}); + + auto bmm_dP_K_attributes = Matmul_fp8_attributes().set_name("bmm_dP_K"); + // Special non-functional-style call. Needed because output already created and provided to user. + matmul_fp8(dP, + K, + attributes.inputs[input_names::Descale_dP], + attributes.inputs[input_names::Descale_K], + attributes.inputs[input_names::Scale_dQ], + bmm_dP_K_attributes, + attributes.outputs[output_names::dQ], + attributes.outputs[output_names::Amax_dQ]); + + //// dP.T * Q + auto dP_T_attributes = Reshape_attributes().set_name("dP_T"); + auto dP_T = reshape(dP, dP_T_attributes); + dP_T->set_data_type(attributes.inputs.at(input_names::dO)->get_data_type()); + dP_T->set_name("dP_T").set_dim({b, h_q, s_kv, s_q}).set_stride({h_q * s_q * s_kv, s_q * s_kv, 1, s_kv}); + + auto bmm_dP_T_Q_attributes = Matmul_fp8_attributes().set_name("bmm_dP_T_Q"); + // Special non-functional-style call. Needed because output already created and provided to user. 
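Elementwise, the backward expansion above reconstructs the softmax gradient before the fp8 GEMMs that produce dV, dQ and dK: the re-materialized probability exp(S - Stats) multiplies the row-corrected dO * V^T term, with the descale and scale pointwise nodes folding fp8 quantization in and out around it. A scalar sketch of that chain, assuming plain float values:

    // Per-element view of the dP/dS chain built above.
    // p           : exp(S_masked - Stats), the re-materialized softmax probability
    // dp_acc      : accumulator of the dO * V^T GEMM, already descaled by dO and V
    // softmax_sum : rowsum(dO * O), descaled by dO and O
    // attn_scale  : optional attention scale factor, as in the graph
    inline float softmax_backward_elem(float p, float dp_acc, float softmax_sum, float attn_scale) {
        float dP = dp_acc - softmax_sum;  // subtract the row-wise correction term
        float dS = dP * p;                // multiply by the softmax probability
        return dS * attn_scale;           // optional attn_scale multiplication
    }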
+ matmul_fp8(dP_T, + attributes.inputs[input_names::Q], + attributes.inputs[input_names::Descale_dP], + attributes.inputs[input_names::Descale_Q], + attributes.inputs[input_names::Scale_dK], + bmm_dP_T_Q_attributes, + attributes.outputs[output_names::dK], + attributes.outputs[output_names::Amax_dK]); + + return {error_code_t::OK, ""}; + } + + error_t + post_validate_node() const override final { + // Validate outputs + // All properties of output tensors should have been set now. + CHECK_CUDNN_FRONTEND_ERROR(attributes.validate_outputs()); + + return {error_code_t::OK, ""}; + } + + virtual void + serialize(json& j) const override final { + j = attributes; + j.update(R"({"tag": "SDPA_FP8_BWD"})"_json); + } +}; + +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node/softmax.h b/include/cudnn_frontend/node/softmax.h index 4c5852e1..e0064027 100644 --- a/include/cudnn_frontend/node/softmax.h +++ b/include/cudnn_frontend/node/softmax.h @@ -129,26 +129,27 @@ class SoftmaxNode : public NodeCRTP { } }; -inline void INode::softmax(std::shared_ptr p, - Softmax_attributes attributes, - std::shared_ptr s, - std::shared_ptr stats) { - attributes.inputs[Softmax_attributes::input_names::P] = p; - attributes.outputs[Softmax_attributes::output_names::S] = s; - attributes.outputs[Softmax_attributes::output_names::Stats] = stats; - sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +inline void +INode::softmax(std::shared_ptr p, + Softmax_attributes attributes, + std::shared_ptr s, + std::shared_ptr stats) { + attributes.inputs[Softmax_attributes::input_names::P] = p; + attributes.outputs[Softmax_attributes::output_names::S] = s; + attributes.outputs[Softmax_attributes::output_names::Stats] = stats; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); } -inline void INode::softmax(std::shared_ptr p, - Softmax_attributes attributes, - std::shared_ptr s, - std::shared_ptr m, - std::shared_ptr zinv) { - attributes.inputs[Softmax_attributes::input_names::P] = p; - attributes.outputs[Softmax_attributes::output_names::S] = s; - attributes.outputs[Softmax_attributes::output_names::M] = m; - attributes.outputs[Softmax_attributes::output_names::Zinv] = zinv; - sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); +inline void +INode::softmax(std::shared_ptr p, + Softmax_attributes attributes, + std::shared_ptr s, + std::shared_ptr m, + std::shared_ptr zinv) { + attributes.inputs[Softmax_attributes::input_names::P] = p; + attributes.outputs[Softmax_attributes::output_names::S] = s; + attributes.outputs[Softmax_attributes::output_names::M] = m; + attributes.outputs[Softmax_attributes::output_names::Zinv] = zinv; + sub_nodes.emplace_back(std::make_unique(std::move(attributes), context)); } - -} // namespace cudnn_frontend::graph +} // namespace cudnn_frontend::graph \ No newline at end of file diff --git a/include/cudnn_frontend/node_interface.h b/include/cudnn_frontend/node_interface.h index a9164144..df28a939 100644 --- a/include/cudnn_frontend/node_interface.h +++ b/include/cudnn_frontend/node_interface.h @@ -23,12 +23,13 @@ namespace cudnn_frontend { namespace graph { - class BatchNormNode; class DBNNode; class MatmulNode; +class MatmulFP8Node; class PointwiseNode; class ReductionNode; +class ResampleNode; class ReshapeNode; class RngNode; class SoftmaxNode; @@ -243,36 +244,46 @@ class INode : public ICudnn { Matmul_attributes attributes, std::shared_ptr c); - inline void + void + 
matmul_fp8(std::shared_ptr a, + std::shared_ptr b, + std::shared_ptr descale_a, + std::shared_ptr descale_b, + std::shared_ptr scale_c, + Matmul_fp8_attributes attributes, + std::shared_ptr c, + std::shared_ptr amax_c); + + void softmax(std::shared_ptr p, Softmax_attributes attributes, std::shared_ptr s, std::shared_ptr stats); - inline void + void softmax(std::shared_ptr p, Softmax_attributes attributes, std::shared_ptr s, std::shared_ptr m, std::shared_ptr zinv); - inline void + void pointwise(std::shared_ptr a, Pointwise_attributes attributes, std::shared_ptr c); - inline void + void pointwise(std::shared_ptr a, std::shared_ptr b, Pointwise_attributes attributes, std::shared_ptr c); - inline void + void reduction(std::shared_ptr a, Reduction_attributes attributes, std::shared_ptr c); - inline void + void rng(std::shared_ptr seed, std::shared_ptr offset, Rng_attributes attributes, @@ -353,25 +364,26 @@ class INode : public ICudnn { virtual Type getType() = 0; - inline std::shared_ptr matmul(std::shared_ptr, - std::shared_ptr, - Matmul_attributes); - - inline std::shared_ptr pointwise(std::shared_ptr, Pointwise_attributes); - inline std::shared_ptr pointwise(std::shared_ptr, - std::shared_ptr, - Pointwise_attributes); - inline std::shared_ptr pointwise(std::shared_ptr, - std::shared_ptr, - std::shared_ptr, - Pointwise_attributes); - - inline std::shared_ptr reduction(std::shared_ptr, Reduction_attributes); - inline std::shared_ptr reshape(std::shared_ptr, Reshape_attributes); - - inline std::shared_ptr rng(std::shared_ptr, - std::shared_ptr, - Rng_attributes); + std::shared_ptr matmul(std::shared_ptr, + std::shared_ptr, + Matmul_attributes); + + std::shared_ptr pointwise(std::shared_ptr, Pointwise_attributes); + std::shared_ptr pointwise(std::shared_ptr, + std::shared_ptr, + Pointwise_attributes); + std::shared_ptr pointwise(std::shared_ptr, + std::shared_ptr, + std::shared_ptr, + Pointwise_attributes); + + std::shared_ptr reduction(std::shared_ptr, Reduction_attributes); + std::array, 2> resample(std::shared_ptr, Resample_attributes); + std::shared_ptr reshape(std::shared_ptr, Reshape_attributes); + + std::shared_ptr rng(std::shared_ptr, + std::shared_ptr, + Rng_attributes); error_t validate() { // infer_properties self @@ -749,4 +761,4 @@ class NodeCRTP : public INode { } // namespace graph -} // namespace cudnn_frontend +} // namespace cudnn_frontend \ No newline at end of file diff --git a/include/cudnn_frontend/plans.h b/include/cudnn_frontend/plans.h index cf8da82a..c017b5b3 100644 --- a/include/cudnn_frontend/plans.h +++ b/include/cudnn_frontend/plans.h @@ -221,10 +221,11 @@ class Execution_plan_list { std::string operation_tag; std::vector> numeric_notes; std::vector> behavior_notes; - std::vector filtered_indices; + std::vector barred_indices; int64_t max_workspace_allowed = std::numeric_limits::max(); + std::vector barred_engine_names = {}; EngineConfigList engine_configs; public: @@ -254,7 +255,7 @@ class Execution_plan_list { numeric_notes.reserve(engine_configs.size()); behavior_notes.reserve(engine_configs.size()); - filtered_indices.resize(engine_configs.size(), 0); + barred_indices.resize(engine_configs.size(), 0); execution_plans.resize(engine_configs.size()); for (auto& engine_config : engine_configs) { @@ -321,38 +322,33 @@ class Execution_plan_list { } error_t - deselect_numeric_notes(std::vector const& notes) { + filter_numeric_notes(std::vector const& notes, bool const keep) { for (auto& note : notes) { cudnnBackendNumericalNote_t backend_note; - 
RETURN_CUDNN_FRONTEND_ERROR_IF(detail::convert_to_cudnn_type(note, backend_note) != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - "Unexpected behaviour note provided."); - + auto valid_note = (detail::convert_to_cudnn_type(note, backend_note) == CUDNN_STATUS_SUCCESS); for (auto i = 0u; i < engine_configs.size(); i++) { - if (std::find(numeric_notes[i].begin(), numeric_notes[i].end(), backend_note) != - numeric_notes[i].end()) { - filtered_indices[i] = true; - } + bool has_barred_note = + std::find(numeric_notes[i].begin(), numeric_notes[i].end(), backend_note) != numeric_notes[i].end(); + + barred_indices[i] = has_barred_note && valid_note ? !keep : keep; } } return {error_code_t::OK, ""}; } error_t - deselect_behavior_notes(std::vector const& notes) { + filter_behavior_notes(std::vector const& notes, bool const keep) { for (auto& note : notes) { cudnnBackendBehaviorNote_t backend_note; - RETURN_CUDNN_FRONTEND_ERROR_IF(detail::convert_to_cudnn_type(note, backend_note) != CUDNN_STATUS_SUCCESS, - error_code_t::CUDNN_BACKEND_API_FAILED, - "Unexpected behaviour note provided."); + auto valid_note = (detail::convert_to_cudnn_type(note, backend_note) == CUDNN_STATUS_SUCCESS); for (auto i = 0u; i < engine_configs.size(); i++) { - if (std::find(behavior_notes[i].begin(), behavior_notes[i].end(), backend_note) != - behavior_notes[i].end()) { - filtered_indices[i] = true; - } + bool has_barred_note = std::find(behavior_notes[i].begin(), behavior_notes[i].end(), backend_note) != + numeric_notes[i].end(); + + barred_indices[i] = has_barred_note && valid_note ? !keep : keep; } } return {error_code_t::OK, ""}; @@ -363,25 +359,30 @@ class Execution_plan_list { max_workspace_allowed = workspace_allowed; } + void + set_barred_names(std::vector const& engine_names) { + barred_engine_names = engine_names; + } + EngineConfigList - get_filtered_engine_configs() { - EngineConfigList filtered_engine_configs; + get_barred_engine_configs() { + EngineConfigList barred_engine_configs; getLogger() << "[cudnn_frontend] INFO: " << " Filtering engine_configs ..." << engine_configs.size() << std::endl; for (auto i = 0u; i < engine_configs.size(); i++) { - if (filtered_indices[i] == false) { - filtered_engine_configs.push_back(engine_configs[i]); + if (barred_indices[i] == false) { + barred_engine_configs.push_back(engine_configs[i]); } } getLogger() << "[cudnn_frontend] INFO: " - << " Filtered engine_configs ..." << filtered_engine_configs.size() << std::endl; - return filtered_engine_configs; + << " barred engine_configs ..." 
<< barred_engine_configs.size() << std::endl; + return barred_engine_configs; } error_t check_support(cudnnHandle_t handle) { for (auto i = 0u; i < engine_configs.size(); i++) { - if (filtered_indices[i]) { + if (barred_indices[i]) { getLogger() << "[cudnn_frontend] INFO: Deselecting execution plan at position " << i << std::endl; continue; } @@ -395,14 +396,33 @@ class Execution_plan_list { if (fe_status.is_good()) { // Filter out execution plans with workspace greater than whats available from user if (execution_plans[i]->getWorkspaceSize() > max_workspace_allowed) { - filtered_indices[i] = true; - execution_plans[i] = nullptr; + barred_indices[i] = true; + execution_plans[i] = nullptr; getLogger() << "[cudnn_frontend] INFO: Deselecting execution plan at position " << i << std::endl; continue; } + auto is_blocked = [](std::string const& full_name, + std::vector const& blocked_names) -> bool { + for (auto const& blocked_name : blocked_names) { + if (full_name.find(blocked_name) != std::string::npos) { + return true; + } + } + return false; + }; + + if (is_blocked(execution_plans[i]->getTag(), barred_engine_names)) { + getLogger() << "[cudnn_frontend] INFO: Deselecting execution plan " << execution_plans[i]->getTag() + << std::endl; + barred_indices[i] = true; + execution_plans[i] = nullptr; + continue; + } + candidate = static_cast(i); - getLogger() << "[cudnn_frontend] INFO: Candidate set as " << i << std::endl; + getLogger() << "[cudnn_frontend] INFO: Candidate set as " << i << " " << execution_plans[i]->getTag() + << std::endl; return {error_code_t::OK, ""}; } @@ -427,7 +447,7 @@ class Execution_plan_list { error_t build_plan_at_index(cudnnHandle_t handle, int64_t index) { - RETURN_CUDNN_FRONTEND_ERROR_IF(filtered_indices[index] == true, + RETURN_CUDNN_FRONTEND_ERROR_IF(barred_indices[index] == true, error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, "Chosen plan index has been deselected."); @@ -447,7 +467,7 @@ class Execution_plan_list { if (execution_plans[index]->getWorkspaceSize() <= max_workspace_allowed) { candidate = index; } else { - filtered_indices[index] = true; + barred_indices[index] = true; return {error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED, "[cudnn_frontend] Error: Workspace size is too large."}; } @@ -469,7 +489,7 @@ class Execution_plan_list { } for (auto i = 0u; i < engine_configs.size(); i++) { - if (filtered_indices[i]) { + if (barred_indices[i]) { getLogger() << "[cudnn_frontend] INFO: Skipping deselected engine plan at index " << i << std::endl; continue; } @@ -483,8 +503,26 @@ class Execution_plan_list { if (execution_plans[i]->getWorkspaceSize() > max_workspace_allowed) { getLogger() << "[cudnn_frontend] INFO: skipping plan since workspace violation. 
Requires " << execution_plans[i]->getWorkspaceSize() << std::endl; - filtered_indices[i] = true; - execution_plans[i] = nullptr; + barred_indices[i] = true; + execution_plans[i] = nullptr; + continue; + } + + auto is_blocked = [](std::string const& full_name, + std::vector const& blocked_names) -> bool { + for (auto const& blocked_name : blocked_names) { + if (full_name.find(blocked_name) != std::string::npos) { + return true; + } + } + return false; + }; + + if (is_blocked(execution_plans[i]->getTag(), barred_engine_names)) { + getLogger() << "[cudnn_frontend] INFO: Deselecting execution plan " << execution_plans[i]->getTag() + << std::endl; + barred_indices[i] = true; + execution_plans[i] = nullptr; continue; } // Only set the candidate the first time, as the order of iteration is from highest to lowest priority @@ -493,6 +531,9 @@ class Execution_plan_list { getLogger() << "[cudnn_frontend] INFO: Candidate set as " << i << std::endl; } + getLogger() << "[cudnn_frontend] INFO: Built plan at " << i << " " << execution_plans[i]->getTag() + << std::endl; + // Return from this function as first successfully built plan is found. if (policy == BuildPlanPolicy_t::HEURISTICS_CHOICE) { return {error_code_t::OK, ""}; @@ -608,4 +649,4 @@ class Execution_plan_list { }; } // namespace graph -} // namespace cudnn_frontend \ No newline at end of file +} // namespace cudnn_frontend diff --git a/include/cudnn_frontend/thirdparty/nlohmann/json.hpp b/include/cudnn_frontend/thirdparty/nlohmann/json.hpp index dc164b3f..2c9001a4 100644 --- a/include/cudnn_frontend/thirdparty/nlohmann/json.hpp +++ b/include/cudnn_frontend/thirdparty/nlohmann/json.hpp @@ -1,9 +1,9 @@ // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT /****************************************************************************\ @@ -34,10 +34,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -47,10 +47,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -59,7 +59,7 @@ #ifndef JSON_SKIP_LIBRARY_VERSION_CHECK #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH) - #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 2 + #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3 #warning "Already included a different version of the library!" 
#endif #endif @@ -67,7 +67,7 @@ #define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum) #define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum) -#define NLOHMANN_JSON_VERSION_PATCH 2 // NOLINT(modernize-macro-to-enum) +#define NLOHMANN_JSON_VERSION_PATCH 3 // NOLINT(modernize-macro-to-enum) #ifndef JSON_DIAGNOSTICS #define JSON_DIAGNOSTICS 0 @@ -149,10 +149,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -172,10 +172,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -192,10 +192,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -208,10 +208,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -220,10 +220,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -233,10 +233,10 @@ // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -320,10 +320,10 @@ NLOHMANN_JSON_NAMESPACE_END // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-FileCopyrightText: 2016-2021 Evan Nemerson // SPDX-License-Identifier: MIT @@ -2485,6 +2485,14 @@ JSON_HEDLEY_DIAGNOSTIC_POP #endif #endif +#ifndef JSON_HAS_STATIC_RTTI + #if !defined(_HAS_STATIC_RTTI) || _HAS_STATIC_RTTI != 0 + #define JSON_HAS_STATIC_RTTI 1 + #else + #define JSON_HAS_STATIC_RTTI 0 + #endif +#endif + #ifdef JSON_HAS_CPP_17 #define JSON_INLINE_VARIABLE inline #else @@ -2750,6 +2758,9 @@ JSON_HEDLEY_DIAGNOSTIC_POP friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { 
NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } +#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(Type, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } + /*! @brief macro @def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE @@ -2759,11 +2770,13 @@ JSON_HEDLEY_DIAGNOSTIC_POP inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } +#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_ONLY_SERIALIZE(Type, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } + #define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, ...) \ inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } - // inspired from https://stackoverflow.com/a/26745591 // allows to call any std function as if (e.g. with begin): // using std::begin; begin(x); @@ -2926,10 +2939,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -3001,10 +3014,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -3043,10 +3056,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-FileCopyrightText: 2018 The Abseil Authors // SPDX-License-Identifier: MIT @@ -3217,10 +3230,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -3229,14 +3242,15 @@ NLOHMANN_JSON_NAMESPACE_END #include // false_type, is_constructible, 
is_integral, is_same, true_type #include // declval #include // tuple +#include // char_traits // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -3301,10 +3315,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -3321,10 +3335,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -3345,10 +3359,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT #ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ @@ -3581,6 +3595,63 @@ struct actual_object_comparator template using actual_object_comparator_t = typename actual_object_comparator::type; +///////////////// +// char_traits // +///////////////// + +// Primary template of char_traits calls std char_traits +template +struct char_traits : std::char_traits +{}; + +// Explicitly define char traits for unsigned char since it is not standard +template<> +struct char_traits : std::char_traits +{ + using char_type = unsigned char; + using int_type = uint64_t; + + // Redefine to_int_type function + static int_type to_int_type(char_type c) noexcept + { + return static_cast(c); + } + + static char_type to_char_type(int_type i) noexcept + { + return static_cast(i); + } + + static constexpr int_type eof() noexcept + { + return static_cast(EOF); + } +}; + +// Explicitly define char traits for signed char since it is not standard +template<> +struct char_traits : std::char_traits +{ + using char_type = signed char; + using int_type = uint64_t; + + // Redefine to_int_type function + static int_type to_int_type(char_type c) noexcept + { + return static_cast(c); + } + + static char_type to_char_type(int_type i) noexcept + { + return static_cast(i); + } + + static constexpr int_type eof() noexcept + { + return static_cast(EOF); + } +}; + /////////////////// // is_ functions // /////////////////// @@ -3617,7 +3688,6 @@ template struct is_default_constructible> : conjunction...> {}; - template struct is_constructible : std::is_constructible {}; @@ -3633,7 +3703,6 @@ struct is_constructible> : is_default_constructible struct is_constructible> : is_default_constructible> {}; - template struct is_iterator_traits : std::false_type {}; @@ -4043,7 +4112,6 @@ struct value_in_range_of_impl2 } }; - template struct value_in_range_of_impl2 { @@ -4142,10 +4210,10 @@ NLOHMANN_JSON_NAMESPACE_END // 
#include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -4290,7 +4358,6 @@ inline OutStringType concat(Args && ... args) NLOHMANN_JSON_NAMESPACE_END - NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail { @@ -4529,10 +4596,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -4553,10 +4620,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -5059,10 +5126,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -5079,10 +5146,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -5151,10 +5218,10 @@ template class iteration_proxy_value // older GCCs are a bit fussy and require explicit noexcept specifiers on defaulted functions iteration_proxy_value(iteration_proxy_value&&) noexcept(std::is_nothrow_move_constructible::value - && std::is_nothrow_move_constructible::value) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + && std::is_nothrow_move_constructible::value) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor,cppcoreguidelines-noexcept-move-operations) iteration_proxy_value& operator=(iteration_proxy_value&&) noexcept(std::is_nothrow_move_assignable::value - && std::is_nothrow_move_assignable::value) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor) + && std::is_nothrow_move_assignable::value) = default; // NOLINT(hicpp-noexcept-move,performance-noexcept-move-constructor,cppcoreguidelines-noexcept-move-operations) ~iteration_proxy_value() = default; /// dereference operator (needed for range-based for) @@ -5800,10 +5867,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // 
SPDX-License-Identifier: MIT @@ -5912,10 +5979,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -6045,10 +6112,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -6071,10 +6138,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -6098,6 +6165,8 @@ NLOHMANN_JSON_NAMESPACE_END // #include +// #include + NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail @@ -6144,7 +6213,6 @@ class file_input_adapter std::FILE* m_file; }; - /*! Input adapter for a (caching) istream. Ignores a UFT Byte Order Mark at beginning of input. Does not support changing the underlying std::streambuf @@ -6218,16 +6286,16 @@ class iterator_input_adapter : current(std::move(first)), end(std::move(last)) {} - typename std::char_traits::int_type get_character() + typename char_traits::int_type get_character() { if (JSON_HEDLEY_LIKELY(current != end)) { - auto result = std::char_traits::to_int_type(*current); + auto result = char_traits::to_int_type(*current); std::advance(current, 1); return result; } - return std::char_traits::eof(); + return char_traits::eof(); } private: @@ -6243,7 +6311,6 @@ class iterator_input_adapter } }; - template struct wide_string_input_helper; @@ -6367,7 +6434,7 @@ struct wide_string_input_helper } }; -// Wraps another input apdater to convert wide character types into individual bytes. +// Wraps another input adapter to convert wide character types into individual bytes. template class wide_string_input_adapter { @@ -6412,7 +6479,6 @@ class wide_string_input_adapter std::size_t utf8_bytes_filled = 0; }; - template struct iterator_input_adapter_factory { @@ -6569,10 +6635,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -6714,7 +6780,6 @@ struct json_sax virtual ~json_sax() = default; }; - namespace detail { /*! 
@@ -7163,7 +7228,7 @@ class json_sax_dom_callback_parser if (ref_stack.empty()) { root = std::move(value); - return {true, &root}; + return {true, & root}; } // skip this value if we already decided to skip the parent @@ -7180,7 +7245,7 @@ class json_sax_dom_callback_parser if (ref_stack.back()->is_array()) { ref_stack.back()->m_data.m_value.array->emplace_back(std::move(value)); - return {true, &(ref_stack.back()->m_data.m_value.array->back())}; + return {true, & (ref_stack.back()->m_data.m_value.array->back())}; } // object @@ -7302,10 +7367,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -7326,6 +7391,8 @@ NLOHMANN_JSON_NAMESPACE_END // #include +// #include + NLOHMANN_JSON_NAMESPACE_BEGIN namespace detail @@ -7420,7 +7487,7 @@ class lexer : public lexer_base using number_float_t = typename BasicJsonType::number_float_t; using string_t = typename BasicJsonType::string_t; using char_type = typename InputAdapterType::char_type; - using char_int_type = typename std::char_traits::int_type; + using char_int_type = typename char_traits::int_type; public: using token_type = typename lexer_base::token_type; @@ -7527,7 +7594,7 @@ class lexer : public lexer_base for (auto range = ranges.begin(); range != ranges.end(); ++range) { get(); - if (JSON_HEDLEY_LIKELY(*range <= current && current <= *(++range))) + if (JSON_HEDLEY_LIKELY(*range <= current && current <= *(++range))) // NOLINT(bugprone-inc-dec-in-conditions) { add(current); } @@ -7570,7 +7637,7 @@ class lexer : public lexer_base switch (get()) { // end of file while parsing string - case std::char_traits::eof(): + case char_traits::eof(): { error_message = "invalid string: missing closing quote"; return token_type::parse_error; @@ -8159,7 +8226,7 @@ class lexer : public lexer_base { case '\n': case '\r': - case std::char_traits::eof(): + case char_traits::eof(): case '\0': return true; @@ -8176,7 +8243,7 @@ class lexer : public lexer_base { switch (get()) { - case std::char_traits::eof(): + case char_traits::eof(): case '\0': { error_message = "invalid comment; missing closing '*/'"; @@ -8605,10 +8672,10 @@ class lexer : public lexer_base token_type scan_literal(const char_type* literal_text, const std::size_t length, token_type return_type) { - JSON_ASSERT(std::char_traits::to_char_type(current) == literal_text[0]); + JSON_ASSERT(char_traits::to_char_type(current) == literal_text[0]); for (std::size_t i = 1; i < length; ++i) { - if (JSON_HEDLEY_UNLIKELY(std::char_traits::to_char_type(get()) != literal_text[i])) + if (JSON_HEDLEY_UNLIKELY(char_traits::to_char_type(get()) != literal_text[i])) { error_message = "invalid literal"; return token_type::parse_error; @@ -8626,7 +8693,7 @@ class lexer : public lexer_base { token_buffer.clear(); token_string.clear(); - token_string.push_back(std::char_traits::to_char_type(current)); + token_string.push_back(char_traits::to_char_type(current)); } /* @@ -8634,7 +8701,7 @@ class lexer : public lexer_base This function provides the interface to the used input adapter. It does not throw in case the input reached EOF, but returns a - `std::char_traits::eof()` in that case. Stores the scanned characters + `char_traits::eof()` in that case. 
Stores the scanned characters for use in error messages. @return character read from the input @@ -8654,9 +8721,9 @@ class lexer : public lexer_base current = ia.get_character(); } - if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) + if (JSON_HEDLEY_LIKELY(current != char_traits::eof())) { - token_string.push_back(std::char_traits::to_char_type(current)); + token_string.push_back(char_traits::to_char_type(current)); } if (current == '\n') @@ -8695,7 +8762,7 @@ class lexer : public lexer_base --position.chars_read_current_line; } - if (JSON_HEDLEY_LIKELY(current != std::char_traits::eof())) + if (JSON_HEDLEY_LIKELY(current != char_traits::eof())) { JSON_ASSERT(!token_string.empty()); token_string.pop_back(); @@ -8889,7 +8956,7 @@ class lexer : public lexer_base // end of input (the null byte is needed when parsing from // string literals) case '\0': - case std::char_traits::eof(): + case char_traits::eof(): return token_type::end_of_input; // error @@ -8907,7 +8974,7 @@ class lexer : public lexer_base const bool ignore_comments = false; /// the current character - char_int_type current = std::char_traits::eof(); + char_int_type current = char_traits::eof(); /// whether the next get() call should just return current bool next_unget = false; @@ -8941,10 +9008,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -9133,7 +9200,6 @@ static inline bool little_endianness(int num = 1) noexcept return *reinterpret_cast(&num) == 1; } - /////////////////// // binary reader // /////////////////// @@ -9151,7 +9217,7 @@ class binary_reader using binary_t = typename BasicJsonType::binary_t; using json_sax_t = SAX; using char_type = typename InputAdapterType::char_type; - using char_int_type = typename std::char_traits::int_type; + using char_int_type = typename char_traits::int_type; public: /*! @@ -9224,7 +9290,7 @@ class binary_reader get(); } - if (JSON_HEDLEY_UNLIKELY(current != std::char_traits::eof())) + if (JSON_HEDLEY_UNLIKELY(current != char_traits::eof())) { return sax->parse_error(chars_read, get_token_string(), parse_error::create(110, chars_read, exception_message(input_format, concat("expected end of input; last byte: 0x", get_token_string()), "value"), nullptr)); @@ -9307,7 +9373,7 @@ class binary_reader exception_message(input_format_t::bson, concat("string length must be at least 1, is ", std::to_string(len)), "string"), nullptr)); } - return get_string(input_format_t::bson, len - static_cast(1), result) && get() != std::char_traits::eof(); + return get_string(input_format_t::bson, len - static_cast(1), result) && get() != char_traits::eof(); } /*! @@ -9501,7 +9567,7 @@ class binary_reader switch (get_char ? 
get() : current) { // EOF - case std::char_traits::eof(): + case char_traits::eof(): return unexpect_eof(input_format_t::cbor, "value"); // Integer 0x00..0x17 (0..23) @@ -10276,7 +10342,7 @@ class binary_reader switch (get()) { // EOF - case std::char_traits::eof(): + case char_traits::eof(): return unexpect_eof(input_format_t::msgpack, "value"); // positive fixint @@ -11378,7 +11444,7 @@ class binary_reader { switch (prefix) { - case std::char_traits::eof(): // EOF + case char_traits::eof(): // EOF return unexpect_eof(input_format, "value"); case 'T': // true @@ -11823,7 +11889,7 @@ class binary_reader This function provides the interface to the used input adapter. It does not throw in case the input reached EOF, but returns a -'ve valued - `std::char_traits::eof()` in that case. + `char_traits::eof()` in that case. @return character read from the input */ @@ -11965,7 +12031,7 @@ class binary_reader JSON_HEDLEY_NON_NULL(3) bool unexpect_eof(const input_format_t format, const char* context) const { - if (JSON_HEDLEY_UNLIKELY(current == std::char_traits::eof())) + if (JSON_HEDLEY_UNLIKELY(current == char_traits::eof())) { return sax->parse_error(chars_read, "", parse_error::create(110, chars_read, exception_message(format, "unexpected end of input", context), nullptr)); @@ -12032,7 +12098,7 @@ class binary_reader InputAdapterType ia; /// the current character - char_int_type current = std::char_traits::eof(); + char_int_type current = char_traits::eof(); /// the number of characters read std::size_t chars_read = 0; @@ -12094,10 +12160,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -12443,13 +12509,25 @@ class parser m_lexer.get_token_string(), parse_error::create(101, m_lexer.get_position(), exception_message(token_type::uninitialized, "value"), nullptr)); } + case token_type::end_of_input: + { + if (JSON_HEDLEY_UNLIKELY(m_lexer.get_position().chars_read_total == 1)) + { + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), + "attempting to parse an empty input; check that your input string or stream contains the expected JSON", nullptr)); + } + return sax->parse_error(m_lexer.get_position(), + m_lexer.get_token_string(), + parse_error::create(101, m_lexer.get_position(), exception_message(token_type::literal_or_value, "value"), nullptr)); + } case token_type::uninitialized: case token_type::end_array: case token_type::end_object: case token_type::name_separator: case token_type::value_separator: - case token_type::end_of_input: case token_type::literal_or_value: default: // the last token was unexpected { @@ -12611,10 +12689,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -12624,10 +12702,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | 
| |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -12783,10 +12861,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -13545,10 +13623,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -13678,6 +13756,14 @@ NLOHMANN_JSON_NAMESPACE_END // #include // #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann +// SPDX-License-Identifier: MIT + #include // conditional, is_same @@ -13714,10 +13800,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -14709,10 +14795,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -14801,10 +14887,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -14827,10 +14913,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -16795,11 +16881,11 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2008-2009 Björn Hoehrmann -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -16820,11 +16906,11 @@ 
NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // // SPDX-FileCopyrightText: 2009 Florian Loitsch -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -18919,10 +19005,10 @@ NLOHMANN_JSON_NAMESPACE_END // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -19281,7 +19367,9 @@ NLOHMANN_JSON_NAMESPACE_END #if defined(JSON_HAS_CPP_17) - #include + #if JSON_HAS_STATIC_RTTI + #include + #endif #include #endif @@ -19411,7 +19499,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @} - ///////////////////// // container types // ///////////////////// @@ -19453,7 +19540,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @} - /// @brief returns the allocator associated with the container /// @sa https://json.nlohmann.me/api/basic_json/get_allocator/ static allocator_type get_allocator() @@ -19468,7 +19554,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec { basic_json result; - result["copyright"] = "(C) 2013-2022 Niels Lohmann"; + result["copyright"] = "(C) 2013-2023 Niels Lohmann"; result["name"] = "JSON for Modern C++"; result["url"] = "https://github.com/nlohmann/json"; result["version"]["string"] = @@ -19516,7 +19602,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec result["compiler"] = {{"family", "unknown"}, {"version", "unknown"}}; #endif - #if defined(_MSVC_LANG) result["compiler"]["c++"] = std::to_string(_MSVC_LANG); #elif defined(__cplusplus) @@ -19527,7 +19612,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec return result; } - /////////////////////////// // JSON value data types // /////////////////////////// @@ -19735,7 +19819,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec object = nullptr; // silence warning, see #821 if (JSON_HEDLEY_UNLIKELY(t == value_t::null)) { - JSON_THROW(other_error::create(500, "961c151d2e87f2686a955a9be24d316f1362bf21 3.11.2", nullptr)); // LCOV_EXCL_LINE + JSON_THROW(other_error::create(500, "961c151d2e87f2686a955a9be24d316f1362bf21 3.11.3", nullptr)); // LCOV_EXCL_LINE } break; } @@ -20129,7 +20213,10 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec bool is_an_object = std::all_of(init.begin(), init.end(), [](const detail::json_ref& element_ref) { - return element_ref->is_array() && element_ref->size() == 2 && (*element_ref)[0].is_string(); + // The cast is to ensure op[size_type] is called, bearing in mind size_type may not be int; + // (many string types can be constructed from 0 via its null-pointer guise, so we get a + // broken call to op[key_type], the wrong semantics and a 4804 warning on Windows) + return element_ref->is_array() && element_ref->size() == 2 && (*element_ref)[static_cast(0)].is_string(); }); // adjust type if type deduction is not wanted @@ -20349,7 +20436,6 @@ class basic_json // 
NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec assert_invariant(); } - /////////////////////////////////////// // other constructors and destructor // /////////////////////////////////////// @@ -21107,7 +21193,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec #if defined(JSON_HAS_CPP_17) && (defined(__GNUC__) || (defined(_MSC_VER) && _MSC_VER >= 1910 && _MSC_VER <= 1914)) detail::negation>, #endif -#if defined(JSON_HAS_CPP_17) +#if defined(JSON_HAS_CPP_17) && JSON_HAS_STATIC_RTTI detail::negation>, #endif detail::is_detected_lazy @@ -21144,7 +21230,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @} - //////////////////// // element access // //////////////////// @@ -21859,7 +21944,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @} - //////////// // lookup // //////////// @@ -21977,7 +22061,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @} - /////////////// // iterators // /////////////// @@ -22116,7 +22199,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @} - ////////////// // capacity // ////////////// @@ -22238,7 +22320,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @} - /////////////// // modifiers // /////////////// @@ -22686,7 +22767,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec void swap(reference other) noexcept ( std::is_nothrow_move_constructible::value&& std::is_nothrow_move_assignable::value&& - std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_constructible::value&& // NOLINT(cppcoreguidelines-noexcept-swap,performance-noexcept-swap) std::is_nothrow_move_assignable::value ) { @@ -22703,7 +22784,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec friend void swap(reference left, reference right) noexcept ( std::is_nothrow_move_constructible::value&& std::is_nothrow_move_assignable::value&& - std::is_nothrow_move_constructible::value&& + std::is_nothrow_move_constructible::value&& // NOLINT(cppcoreguidelines-noexcept-swap,performance-noexcept-swap) std::is_nothrow_move_assignable::value ) { @@ -22712,7 +22793,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @brief exchanges the values /// @sa https://json.nlohmann.me/api/basic_json/swap/ - void swap(array_t& other) // NOLINT(bugprone-exception-escape) + void swap(array_t& other) // NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap) { // swap only works for arrays if (JSON_HEDLEY_LIKELY(is_array())) @@ -22728,7 +22809,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @brief exchanges the values /// @sa https://json.nlohmann.me/api/basic_json/swap/ - void swap(object_t& other) // NOLINT(bugprone-exception-escape) + void swap(object_t& other) // NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap) { // swap only works for objects if (JSON_HEDLEY_LIKELY(is_object())) @@ -22744,7 +22825,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @brief exchanges the values /// @sa https://json.nlohmann.me/api/basic_json/swap/ - void swap(string_t& other) // NOLINT(bugprone-exception-escape) + void swap(string_t& other) // 
NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap) { // swap only works for strings if (JSON_HEDLEY_LIKELY(is_string())) @@ -22760,7 +22841,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec /// @brief exchanges the values /// @sa https://json.nlohmann.me/api/basic_json/swap/ - void swap(binary_t& other) // NOLINT(bugprone-exception-escape) + void swap(binary_t& other) // NOLINT(bugprone-exception-escape,cppcoreguidelines-noexcept-swap,performance-noexcept-swap) { // swap only works for strings if (JSON_HEDLEY_LIKELY(is_binary())) @@ -23225,7 +23306,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec #endif // JSON_NO_IO /// @} - ///////////////////// // deserialization // ///////////////////// @@ -23406,7 +23486,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec } } - JSON_PRIVATE_UNLESS_TESTED: ////////////////////// // member variables // @@ -23624,7 +23703,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec return from_cbor(ptr, ptr + len, strict, allow_exceptions, tag_handler); } - JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DEPRECATED_FOR(3.8.0, from_cbor(ptr, ptr + len)) static basic_json from_cbor(detail::span_input_adapter&& i, @@ -23748,7 +23826,6 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec return res ? result : basic_json(value_t::discarded); } - /// @brief create a JSON value from an input in BJData format /// @sa https://json.nlohmann.me/api/basic_json/from_bjdata/ template @@ -24029,7 +24106,7 @@ class basic_json // NOLINT(cppcoreguidelines-special-member-functions,hicpp-spec }; // wrapper for "remove" operation; remove value at ptr - const auto operation_remove = [this, &result](json_pointer & ptr) + const auto operation_remove = [this, & result](json_pointer & ptr) { // get reference to parent of JSON pointer ptr const auto last_path = ptr.back(); @@ -24392,7 +24469,11 @@ inline namespace json_literals /// @brief user-defined string literal for JSON values /// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json/ JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json operator "" _json(const char* s, std::size_t n) +#if !defined(JSON_HEDLEY_GCC_VERSION) || JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) + inline nlohmann::json operator ""_json(const char* s, std::size_t n) +#else + inline nlohmann::json operator "" _json(const char* s, std::size_t n) +#endif { return nlohmann::json::parse(s, s + n); } @@ -24400,7 +24481,11 @@ inline nlohmann::json operator "" _json(const char* s, std::size_t n) /// @brief user-defined string literal for JSON pointer /// @sa https://json.nlohmann.me/api/basic_json/operator_literal_json_pointer/ JSON_HEDLEY_NON_NULL(1) -inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) +#if !defined(JSON_HEDLEY_GCC_VERSION) || JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) + inline nlohmann::json::json_pointer operator ""_json_pointer(const char* s, std::size_t n) +#else + inline nlohmann::json::json_pointer operator "" _json_pointer(const char* s, std::size_t n) +#endif { return nlohmann::json::json_pointer(std::string(s, n)); } @@ -24453,7 +24538,7 @@ struct less< ::nlohmann::detail::value_t> // do not remove the space after '<', /// @sa https://json.nlohmann.me/api/basic_json/std_swap/ NLOHMANN_BASIC_JSON_TPL_DECLARATION inline void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL& j1, nlohmann::NLOHMANN_BASIC_JSON_TPL& j2) 
noexcept( // NOLINT(readability-inconsistent-declaration-parameter-name, cert-dcl58-cpp) - is_nothrow_move_constructible::value&& // NOLINT(misc-redundant-expression) + is_nothrow_move_constructible::value&& // NOLINT(misc-redundant-expression,cppcoreguidelines-noexcept-swap,performance-noexcept-swap) is_nothrow_move_assignable::value) { j1.swap(j2); @@ -24464,17 +24549,22 @@ inline void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL& j1, nlohmann::NLOHMANN_BASIC } // namespace std #if JSON_USE_GLOBAL_UDLS - using nlohmann::literals::json_literals::operator "" _json; // NOLINT(misc-unused-using-decls,google-global-names-in-headers) - using nlohmann::literals::json_literals::operator "" _json_pointer; //NOLINT(misc-unused-using-decls,google-global-names-in-headers) + #if !defined(JSON_HEDLEY_GCC_VERSION) || JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) + using nlohmann::literals::json_literals::operator ""_json; // NOLINT(misc-unused-using-decls,google-global-names-in-headers) + using nlohmann::literals::json_literals::operator ""_json_pointer; //NOLINT(misc-unused-using-decls,google-global-names-in-headers) + #else + using nlohmann::literals::json_literals::operator "" _json; // NOLINT(misc-unused-using-decls,google-global-names-in-headers) + using nlohmann::literals::json_literals::operator "" _json_pointer; //NOLINT(misc-unused-using-decls,google-global-names-in-headers) + #endif #endif // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT @@ -24509,16 +24599,17 @@ inline void swap(nlohmann::NLOHMANN_BASIC_JSON_TPL& j1, nlohmann::NLOHMANN_BASIC #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM #undef JSON_HAS_THREE_WAY_COMPARISON #undef JSON_HAS_RANGES + #undef JSON_HAS_STATIC_RTTI #undef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON #endif // #include // __ _____ _____ _____ // __| | __| | | | JSON for Modern C++ -// | | |__ | | | | | | version 3.11.2 +// | | |__ | | | | | | version 3.11.3 // |_____|_____|_____|_|___| https://github.com/nlohmann/json // -// SPDX-FileCopyrightText: 2013-2022 Niels Lohmann +// SPDX-FileCopyrightText: 2013-2023 Niels Lohmann // SPDX-License-Identifier: MIT diff --git a/include/cudnn_frontend/utils/serialize.h b/include/cudnn_frontend/utils/serialize.h index 087f549f..22b57161 100644 --- a/include/cudnn_frontend/utils/serialize.h +++ b/include/cudnn_frontend/utils/serialize.h @@ -214,6 +214,21 @@ NLOHMANN_JSON_SERIALIZE_ENUM(Matmul_attributes::output_names, {Matmul_attributes::output_names::C, "C"}, }) +NLOHMANN_JSON_SERIALIZE_ENUM(Matmul_fp8_attributes::input_names, + { + {Matmul_fp8_attributes::input_names::A, "A"}, + {Matmul_fp8_attributes::input_names::B, "B"}, + {Matmul_fp8_attributes::input_names::Descale_A, "Descale_A"}, + {Matmul_fp8_attributes::input_names::Descale_B, "Descale_B"}, + {Matmul_fp8_attributes::input_names::Scale_C, "Scale_C"}, + }) + +NLOHMANN_JSON_SERIALIZE_ENUM(Matmul_fp8_attributes::output_names, + { + {Matmul_fp8_attributes::output_names::C, "C"}, + {Matmul_fp8_attributes::output_names::Amax_C, "Amax_C"}, + }) + NLOHMANN_JSON_SERIALIZE_ENUM(Pointwise_attributes::input_names, { {Pointwise_attributes::input_names::IN_0, "IN_0"}, @@ -297,6 +312,28 @@ NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_attributes::output_names, {SDPA_attributes::output_names::Stats, "Stats"}, 
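For context on the serialization macros used throughout this file: NLOHMANN_JSON_SERIALIZE_ENUM generates to_json/from_json overloads that map each enumerator to its paired string, so the new Matmul_fp8 and SDPA_fp8 port names round-trip through JSON by name. A minimal standalone sketch (illustrative only; port_names is a hypothetical stand-in, not a frontend type):

// Illustrative sketch only -- not part of the patch.
#include <nlohmann/json.hpp>
#include <cassert>
#include <string>

enum class port_names { Q, K, Descale_Q };

NLOHMANN_JSON_SERIALIZE_ENUM(port_names,
                             {
                                 {port_names::Q, "Q"},
                                 {port_names::K, "K"},
                                 {port_names::Descale_Q, "Descale_Q"},
                             })

int main() {
    nlohmann::json j = port_names::Descale_Q;              // serialized as the mapped string
    assert(j.get<std::string>() == "Descale_Q");
    assert(j.get<port_names>() == port_names::Descale_Q);  // parsed back by name
    return 0;
}

By design of the macro, a string with no matching entry deserializes to the first listed enumerator.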
{SDPA_attributes::output_names::RNG_DUMP, "RNG_DUMP"}}) +NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_attributes::input_names, + { + {SDPA_fp8_attributes::input_names::Q, "Q"}, + {SDPA_fp8_attributes::input_names::K, "K"}, + {SDPA_fp8_attributes::input_names::V, "V"}, + {SDPA_fp8_attributes::input_names::Attn_scale, "Attn_scale"}, + {SDPA_fp8_attributes::input_names::Descale_Q, "Descale_Q"}, + {SDPA_fp8_attributes::input_names::Descale_K, "Descale_K"}, + {SDPA_fp8_attributes::input_names::Descale_V, "Descale_V"}, + {SDPA_fp8_attributes::input_names::Descale_S, "Descale_S"}, + {SDPA_fp8_attributes::input_names::Scale_S, "Scale_S"}, + {SDPA_fp8_attributes::input_names::Scale_O, "Scale_O"}, + }) + +NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_fp8_attributes::output_names, + { + {SDPA_fp8_attributes::output_names::O, "O"}, + {SDPA_fp8_attributes::output_names::Stats, "Stats"}, + {SDPA_fp8_attributes::output_names::Amax_O, "Amax_O"}, + {SDPA_fp8_attributes::output_names::Amax_S, "Amax_S"}, + }) + NLOHMANN_JSON_SERIALIZE_ENUM(SDPA_backward_attributes::input_names, { {SDPA_backward_attributes::input_names::Q, "Q"}, @@ -352,22 +389,7 @@ from_json(const nlohmann::json& j, Tensor_attributes& ta) { ta.uid_assigned = j.at("uid_assigned").get(); if (ta.is_pass_by_value && !j["pass_by_value"].is_null()) { - switch (ta.data_type) { - case DataType_t::HALF: - ta.pass_by_value = j.at("pass_by_value").get(); - break; - case DataType_t::FLOAT: - ta.pass_by_value = j.at("pass_by_value").get(); - break; - case DataType_t::BFLOAT16: - ta.pass_by_value = j.at("pass_by_value").get(); - break; - case DataType_t::INT32: - ta.pass_by_value = j.at("pass_by_value").get(); - break; - default: - throw std::runtime_error("Unsupported data type for pass_by_value"); - } + ta.pass_by_value = j.at("pass_by_value"); } } diff --git a/include/cudnn_frontend_Resample.h b/include/cudnn_frontend_Resample.h index bdcc9e47..7c98befd 100644 --- a/include/cudnn_frontend_Resample.h +++ b/include/cudnn_frontend_Resample.h @@ -218,6 +218,7 @@ class ResampleDescBuilder_v8 { auto setSpatialDim(int64_t count, cudnnFraction_t const *arr) -> ResampleDescBuilder_v8 & { // TODO: check the provided array count against the stored spatial dimension count. + m_resampleDesc.spatialDim = count; std::copy(arr, arr + count, m_resampleDesc.windowDim); return *this; } diff --git a/include/cudnn_frontend_shim.h b/include/cudnn_frontend_shim.h index 1707d36b..f02dfc75 100644 --- a/include/cudnn_frontend_shim.h +++ b/include/cudnn_frontend_shim.h @@ -142,6 +142,11 @@ cuda_get_device_properties(cudaDeviceProp *prop, int device) { NV_FE_CALL_TO_CUDA(cuda_get_device_properties, cudaGetDeviceProperties, prop, device); } +inline cudaError_t +cuda_get_device(int *device) { + NV_FE_CALL_TO_CUDA(cuda_get_device, cudaGetDevice, device); +} + inline const char * cuda_get_error_string(cudaError_t error) { NV_FE_CALL_TO_CUDA(cuda_get_error_string, cudaGetErrorString, error); @@ -183,6 +188,28 @@ get_backend_version(void) { #endif } +namespace detail { + +inline std::string +convert_version_to_str(size_t const version) { + // The multiplier for major version pre-v9 and post-v9 are different. + size_t major = version / 10000; + size_t minor = (version / 100) % 100; + if (major == 0) { + major = version / 1000; + minor = (version / 100) % 10; + } + auto patch = version % 100; + + return std::to_string(major) + "." + std::to_string(minor) + "." 
+ std::to_string(patch); +} +} // namespace detail + +inline std::string +get_backend_version_string() { + return detail::convert_version_to_str(get_backend_version()); +} + inline cudnnStatus_t create_descriptor(cudnnBackendDescriptorType_t descriptorType, cudnnBackendDescriptor_t *descriptor) { NV_FE_CALL_TO_BACKEND(create_descriptor, cudnnBackendCreateDescriptor, descriptorType, descriptor); diff --git a/include/cudnn_frontend_utils.h b/include/cudnn_frontend_utils.h index 43ee10f0..cf3ef1b9 100644 --- a/include/cudnn_frontend_utils.h +++ b/include/cudnn_frontend_utils.h @@ -92,26 +92,31 @@ struct nlohmann::adl_serializer { } }; -template -void -convert_from_json_to_variant(const nlohmann::json& j, std::variant& data) { - try { - data = j.get(); - } catch (...) { - // get will throw an error if incorrect type - } -} - -template -struct nlohmann::adl_serializer> { +template <> +struct nlohmann::adl_serializer> { static void - to_json(nlohmann::json& j, const std::variant& data) { - std::visit([&j](const auto& v) { j = v; }, data); + to_json(nlohmann::json& j, const std::variant& data) { + std::visit([&](const auto& v) { j = {{"index", data.index()}, {"value", v}}; }, data); } static void - from_json(const nlohmann::json& j, std::variant& data) { - (convert_from_json_to_variant(j, data), ...); + from_json(const nlohmann::json& j, std::variant& data) { + if (!j.is_object() || !j.contains("index") || !j.contains("value")) { + throw std::invalid_argument("Invalid JSON format for std::variant"); + } + + size_t type_index = j.at("index").get(); + if (type_index == 0) { + data = j.at("value").get(); + } else if (type_index == 1) { + data = j.at("value").get(); + } else if (type_index == 2) { + data = j.at("value").get(); + } else if (type_index == 3) { + data = j.at("value").get(); + } else { + throw std::out_of_range("Variant index out of range"); + } } }; @@ -155,6 +160,20 @@ struct nlohmann::adl_serializer> { } }; +// Specialization of nlohmann::adl_serializer for cudnnFraction_t +template <> +struct nlohmann::adl_serializer { + static void + to_json(json& j, const cudnnFraction_t& fraction) { + j = fraction.numerator; + } + + static void + from_json(const json& j, cudnnFraction_t& fraction) { + fraction.numerator = j; + } +}; + #include "cudnn_frontend_shim.h" #include "cudnn_backend_base.h" #include "cudnn_frontend_Logging.h" @@ -256,6 +275,12 @@ to_string(cudnnBackendNumericalNote_t note) { return std::string("CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_13x13"); case CUDNN_NUMERICAL_NOTE_TYPE_COUNT: return std::string("CUDNN_NUMERICAL_NOTE_TYPE_COUNT"); + + // If none of the above cases hit, its definitely strict nan prop and should raise an error. 
+#if (CUDNN_VERSION >= 90100) + case CUDNN_NUMERICAL_NOTE_STRICT_NAN_PROP: + return std::string("CUDNN_NUMERICAL_NOTE_STRICT_NAN_PROP"); +#endif #ifndef NO_DEFAULT_IN_SWITCH default: return std::string("UNKNOWN_NUMERICAL_NOTE"); @@ -565,6 +590,7 @@ enum class NumericalNote_t { WINOGRAD_TILE_4x4, WINOGRAD_TILE_6x6, WINOGRAD_TILE_13x13, + STRICT_NAN_PROP, }; NLOHMANN_JSON_SERIALIZE_ENUM(NumericalNote_t, @@ -578,6 +604,7 @@ NLOHMANN_JSON_SERIALIZE_ENUM(NumericalNote_t, {NumericalNote_t::WINOGRAD_TILE_4x4, "WINOGRAD_TILE_4x4"}, {NumericalNote_t::WINOGRAD_TILE_6x6, "WINOGRAD_TILE_6x6"}, {NumericalNote_t::WINOGRAD_TILE_13x13, "WINOGRAD_TILE_13x13"}, + {NumericalNote_t::STRICT_NAN_PROP, "STRICT_NAN_PROP"}, }) enum class DataType_t { @@ -1179,6 +1206,14 @@ convert_to_cudnn_type(cudnn_frontend::NumericalNote_t const mode, cudnnBackendNu case NumericalNote_t::WINOGRAD_TILE_13x13: cudnn_mode = CUDNN_NUMERICAL_NOTE_WINOGRAD_TILE_13x13; return cudnnStatus_t::CUDNN_STATUS_SUCCESS; + case NumericalNote_t::STRICT_NAN_PROP: +#if (CUDNN_VERSION >= 90100) + NV_CUDNN_FE_DYNAMIC_CHECK_CUDNN_BACKEND_VERSION(90100, cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE); + cudnn_mode = CUDNN_NUMERICAL_NOTE_STRICT_NAN_PROP; + return cudnnStatus_t::CUDNN_STATUS_SUCCESS; +#else + return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; +#endif } return cudnnStatus_t::CUDNN_STATUS_INVALID_VALUE; } diff --git a/pyproject.toml b/pyproject.toml index edd73b2e..0f10f34b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ dynamic = ["version"] description = "CUDNN FrontEnd python library" readme = "README.md" requires-python = ">=3.7" -license = {file = "LICENSE.txt"} +license = {text = "NVIDIA Proprietary Software"} classifiers = [ "Programming Language :: Python :: 3", ] diff --git a/python/cudnn/__init__.py b/python/cudnn/__init__.py index 6d50f604..eaf49a1e 100644 --- a/python/cudnn/__init__.py +++ b/python/cudnn/__init__.py @@ -1,5 +1,6 @@ from ._compiled_module import ( backend_version + , backend_version_string , destroy_handle , norm_forward_phase , reduction_mode @@ -13,11 +14,12 @@ , heur_mode , pygraph , tensor + , cudnnGraphNotSupportedError ) from .datatypes import (_library_type, _is_torch_tensor) -__version__ = '1.2.1' +__version__ = '1.3.0' def _tensor( self, diff --git a/python/properties.cpp b/python/properties.cpp index 60d97a29..2e6dd99b 100644 --- a/python/properties.cpp +++ b/python/properties.cpp @@ -140,7 +140,8 @@ init_properties(py::module_& m) { .value("WINOGRAD", cudnn_frontend::NumericalNote_t::WINOGRAD) .value("WINOGRAD_TILE_4x4", cudnn_frontend::NumericalNote_t::WINOGRAD_TILE_4x4) .value("WINOGRAD_TILE_6x6", cudnn_frontend::NumericalNote_t::WINOGRAD_TILE_6x6) - .value("WINOGRAD_TILE_13x13", cudnn_frontend::NumericalNote_t::WINOGRAD_TILE_13x13); + .value("WINOGRAD_TILE_13x13", cudnn_frontend::NumericalNote_t::WINOGRAD_TILE_13x13) + .value("STRICT_NAN_PROP", cudnn_frontend::NumericalNote_t::STRICT_NAN_PROP); py::enum_(m, "behavior_note") .value("RUNTIME_COMPILATION", cudnn_frontend::BehaviorNote_t::RUNTIME_COMPILATION) @@ -149,4 +150,4 @@ init_properties(py::module_& m) { } } // namespace python_bindings -} // namespace cudnn_frontend \ No newline at end of file +} // namespace cudnn_frontend diff --git a/python/pycudnn.cpp b/python/pycudnn.cpp index 0ff5dd08..b43570d5 100644 --- a/python/pycudnn.cpp +++ b/python/pycudnn.cpp @@ -33,7 +33,7 @@ throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::st case cudnn_frontend::error_code_t::INVALID_VARIANT_PACK: throw 
std::invalid_argument(error_msg); case cudnn_frontend::error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED: - throw std::runtime_error(error_msg); + throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str()); case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED: throw std::runtime_error(error_msg); case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED: @@ -45,9 +45,9 @@ throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::st case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE: throw std::runtime_error(error_msg); case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT: - throw std::runtime_error(error_msg); + throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str()); case cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED: - throw std::runtime_error(error_msg); + throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str()); case cudnn_frontend::error_code_t::HANDLE_ERROR: throw std::runtime_error(error_msg); } @@ -68,11 +68,14 @@ set_dlhandle_cudnn(std::intptr_t dlhandle) { PYBIND11_MODULE(_compiled_module, m) { m.def("backend_version", &cudnn_frontend::get_backend_version); + m.def("backend_version_string", &cudnn_frontend::get_backend_version_string); init_properties(m); init_pygraph_submodule(m); m.def("_set_dlhandle_cudnn", &set_dlhandle_cudnn); + + py::register_exception(m, "cudnnGraphNotSupportedError"); } } // namespace python_bindings diff --git a/python/pygraph/pygraph.cpp b/python/pygraph/pygraph.cpp index 0a88b66f..80c91268 100644 --- a/python/pygraph/pygraph.cpp +++ b/python/pygraph/pygraph.cpp @@ -591,6 +591,8 @@ init_pygraph_submodule(py::module_& m) { )pbdoc") .def("deselect_numeric_notes", &PyGraph::deselect_numeric_notes) .def("deselect_behavior_notes", &PyGraph::deselect_behavior_notes) + .def("select_numeric_notes", &PyGraph::select_numeric_notes) + .def("select_behavior_notes", &PyGraph::select_behavior_notes) .def("deselect_workspace_greater_than", &PyGraph::deselect_workspace_greater_than) .def("validate", &PyGraph::validate) .def("key", &PyGraph::key) diff --git a/python/pygraph/pygraph.h b/python/pygraph/pygraph.h index 7ac0a226..acfdb0a9 100644 --- a/python/pygraph/pygraph.h +++ b/python/pygraph/pygraph.h @@ -52,10 +52,9 @@ class PyGraph { .set_intermediate_data_type(intermediate_data_type) .set_io_data_type(io_data_type); - if(handle_.has_value()) { + if (handle_.has_value()) { handle = static_cast((void*)(handle_.value())); - } - else { + } else { cudnn_frontend::create_handle(&handle); is_handle_owner = true; } @@ -247,6 +246,7 @@ class PyGraph { cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); + // return [o, stats] std::array, 2> sdpa(std::shared_ptr& q, std::shared_ptr& k, @@ -264,6 +264,7 @@ class PyGraph { cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); + // return [dQ, dK, dV] std::array, 3> sdpa_backward(std::shared_ptr& q, std::shared_ptr& k, @@ -284,6 +285,48 @@ class PyGraph { cudnn_frontend::DataType_t const& compute_data_type, std::string const& name); + // return [o, stats, amax_s, amax_o] + std::array, 4> + sdpa_fp8(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + std::shared_ptr& descale_q, + std::shared_ptr& descale_k, + std::shared_ptr& descale_v, + std::shared_ptr& descale_s, + std::shared_ptr& scale_s, + std::shared_ptr& scale_o, + bool const is_inference, + py::object const& attn_scale, + bool const use_causal_mask, + cudnn_frontend::DataType_t const& compute_data_type, + std::string 
const& name); + + // return [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] + std::array, 7> + sdpa_fp8_backward(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + std::shared_ptr& o, + std::shared_ptr& dO, + std::shared_ptr& stats, + std::shared_ptr& descale_q, + std::shared_ptr& descale_k, + std::shared_ptr& descale_v, + std::shared_ptr& descale_o, + std::shared_ptr& descale_dO, + std::shared_ptr& descale_s, + std::shared_ptr& descale_dP, + std::shared_ptr& scale_s, + std::shared_ptr& scale_dQ, + std::shared_ptr& scale_dK, + std::shared_ptr& scale_dV, + std::shared_ptr& scale_dP, + py::object const& attn_scale, + bool const use_causal_mask, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name); + void validate(); @@ -315,7 +358,22 @@ class PyGraph { execute(std::unordered_map var_pack, int64_t workspace, std::optional); void - execute_plan_at_index(std::unordered_map var_pack, int64_t workspace, int64_t index, std::optional); + execute_plan_at_index(std::unordered_map var_pack, + int64_t workspace, + int64_t index, + std::optional); + + void + select_numeric_notes(std::vector const& notes) { + graph.select_numeric_notes(notes); + return; + } + + void + select_behavior_notes(std::vector const& notes) { + graph.select_behavior_notes(notes); + return; + } void deselect_numeric_notes(std::vector const& notes) { diff --git a/python/pygraph/sdpa.cpp b/python/pygraph/sdpa.cpp index 3c78ff5e..7527f4d9 100644 --- a/python/pygraph/sdpa.cpp +++ b/python/pygraph/sdpa.cpp @@ -179,6 +179,108 @@ PyGraph::sdpa_backward(std::shared_ptr return {dQ, dK, dV}; } +std::array, 4> +PyGraph::sdpa_fp8(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + std::shared_ptr& descale_q, + std::shared_ptr& descale_k, + std::shared_ptr& descale_v, + std::shared_ptr& descale_s, + std::shared_ptr& scale_s, + std::shared_ptr& scale_o, + bool const is_inference, + py::object const& attn_scale, + bool const use_causal_mask, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = cudnn_frontend::graph::SDPA_fp8_attributes() + .set_is_inference(is_inference) + .set_causal_mask(use_causal_mask) + .set_compute_data_type(compute_data_type) + .set_name(name); + + if (!attn_scale.is_none()) { + if (py::isinstance(attn_scale)) { + auto const attn_scale_value = attn_scale.cast(); + attributes.set_attn_scale(attn_scale_value); + } else { + auto const attn_scale_tensor = attn_scale.cast>(); + if (!attn_scale_tensor) { + throw std::runtime_error("attn_scale must be a cudnn_tensor or float."); + } + attributes.set_attn_scale(attn_scale_tensor); + } + } + + auto [o, stats, amax_s, amax_o] = + graph.sdpa_fp8(q, k, v, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, attributes); + return {o, stats, amax_s, amax_o}; +} + +std::array, 7> +PyGraph::sdpa_fp8_backward(std::shared_ptr& q, + std::shared_ptr& k, + std::shared_ptr& v, + std::shared_ptr& o, + std::shared_ptr& dO, + std::shared_ptr& stats, + std::shared_ptr& descale_q, + std::shared_ptr& descale_k, + std::shared_ptr& descale_v, + std::shared_ptr& descale_o, + std::shared_ptr& descale_dO, + std::shared_ptr& descale_s, + std::shared_ptr& descale_dP, + std::shared_ptr& scale_s, + std::shared_ptr& scale_dQ, + std::shared_ptr& scale_dK, + std::shared_ptr& scale_dV, + std::shared_ptr& scale_dP, + py::object const& attn_scale, + bool const use_causal_mask, + cudnn_frontend::DataType_t const& compute_data_type, + std::string const& name) { + auto attributes = 
cudnn_frontend::graph::SDPA_fp8_backward_attributes() + .set_causal_mask(use_causal_mask) + .set_compute_data_type(compute_data_type) + .set_name(name); + + if (!attn_scale.is_none()) { + if (py::isinstance(attn_scale)) { + auto const attn_scale_value = attn_scale.cast(); + attributes.set_attn_scale(attn_scale_value); + } else { + auto const attn_scale_tensor = attn_scale.cast>(); + if (!attn_scale_tensor) { + throw std::runtime_error("attn_scale must be a cudnn_tensor or float."); + } + attributes.set_attn_scale(attn_scale_tensor); + } + } + + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = graph.sdpa_fp8_backward(q, + k, + v, + o, + dO, + stats, + descale_q, + descale_k, + descale_v, + descale_o, + descale_dO, + descale_s, + descale_dP, + scale_s, + scale_dQ, + scale_dK, + scale_dV, + scale_dP, + attributes); + return {dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP}; +} + void init_pygraph_sdpa_submodule(py::class_& m) { m.def("sdpa", @@ -218,30 +320,30 @@ init_pygraph_sdpa_submodule(py::class_& m) { name (Optional[str]): The name of the operation. Returns: - o (cudnn_tensor): The result of scaled dot-product flash attention. + o (cudnn_tensor): The output data. stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. - )pbdoc") - .def("sdpa_backward", - &PyGraph::sdpa_backward, - py::arg("q"), - py::arg("k"), - py::arg("v"), - py::arg("o"), - py::arg("dO"), - py::arg("stats"), - py::arg_v("attn_scale", py::none()), - py::arg_v("bias", nullptr), - py::arg_v("dBias", nullptr), - py::arg_v("use_alibi_mask", false), - py::arg_v("use_padding_mask", false), - py::arg_v("seq_len_q", nullptr), - py::arg_v("seq_len_kv", nullptr), - py::arg_v("use_causal_mask", false), - py::arg_v("dropout", py::none()), - py::arg_v("rng_dump", nullptr), - py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), - py::arg_v("name", ""), - R"pbdoc( + )pbdoc"); + m.def("sdpa_backward", + &PyGraph::sdpa_backward, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("o"), + py::arg("dO"), + py::arg("stats"), + py::arg_v("attn_scale", py::none()), + py::arg_v("bias", nullptr), + py::arg_v("dBias", nullptr), + py::arg_v("use_alibi_mask", false), + py::arg_v("use_padding_mask", false), + py::arg_v("seq_len_q", nullptr), + py::arg_v("seq_len_kv", nullptr), + py::arg_v("use_causal_mask", false), + py::arg_v("dropout", py::none()), + py::arg_v("rng_dump", nullptr), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( Compute the key, query, value gradients of scaled dot product attention. Args: @@ -264,9 +366,110 @@ init_pygraph_sdpa_submodule(py::class_& m) { name (Optional[str]): The name of the operation. Returns: - dQ (cudnn_tensor): The query gradient tensor of scaled dot-product flash attention. - dK (cudnn_tensor): The key gradient tensor of scaled dot-product flash attention. - dV (cudnn_tensor): The value gradient tensor of scaled dot-product flash attention. + dQ (cudnn_tensor): The query gradient data. + dK (cudnn_tensor): The key gradient data. + dV (cudnn_tensor): The value gradient data. 
+ )pbdoc"); + m.def("sdpa_fp8", + &PyGraph::sdpa_fp8, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("descale_q"), + py::arg("descale_k"), + py::arg("descale_v"), + py::arg("descale_s"), + py::arg("scale_s"), + py::arg("scale_o"), + py::arg("is_inference"), + py::arg_v("attn_scale", py::none()), + py::arg_v("use_causal_mask", false), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Perform scaled dot product attention with fp8 datatype inputs and outputs. + + Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + descale_q (cudnn_tensor): Descale factor for query. + descale_k (cudnn_tensor): Descale factor for key. + descale_v (cudnn_tensor): Descale factor for value. + descale_s (cudnn_tensor): Descale factor for S tensor. + scale_s (cudnn_tensor): Scale factor for S tensor. + scale_o (cudnn_tensor): Scale factor for output. + is_inference (bool): Whether it is an inference step or training step. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + use_causal_mask (Optional[bool]): Whether to use causal mask. Default is False. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + + Returns: + o (cudnn_tensor): The output data. + stats (Optional[cudnn_tensor]): The softmax statistics in case the operation is in a training step. + amax_s (cudnn_tensor): The absolute maximum of S tensor. + amax_o (cudnn_tensor): The absolute maximum of output tensor. + )pbdoc"); + m.def("sdpa_fp8_backward", + &PyGraph::sdpa_fp8_backward, + py::arg("q"), + py::arg("k"), + py::arg("v"), + py::arg("o"), + py::arg("dO"), + py::arg("stats"), + py::arg("descale_q"), + py::arg("descale_k"), + py::arg("descale_v"), + py::arg("descale_o"), + py::arg("descale_dO"), + py::arg("descale_s"), + py::arg("descale_dP"), + py::arg("scale_s"), + py::arg("scale_dQ"), + py::arg("scale_dK"), + py::arg("scale_dV"), + py::arg("scale_dP"), + py::arg_v("attn_scale", py::none()), + py::arg_v("use_causal_mask", false), + py::arg_v("compute_data_type", cudnn_frontend::DataType_t::NOT_SET), + py::arg_v("name", ""), + R"pbdoc( + Compute the key, query, value gradients of scaled dot product attention with fp8 datatype inputs and outputs. + + Args: + q (cudnn_tensor): The query data. + k (cudnn_tensor): The key data. + v (cudnn_tensor): The value data. + o (cudnn_tensor): The output data. + dO (cudnn_tensor): The output gradient data. + stats (cudnn_tensor): The softmax statistics in case the operation is in a training step. + descale_q (cudnn_tensor): Descale factor for query. + descale_k (cudnn_tensor): Descale factor for key. + descale_v (cudnn_tensor): Descale factor for value. + descale_o (cudnn_tensor): Descale factor for output. + descale_dO (cudnn_tensor): Descale factor for output gradient. + descale_s (cudnn_tensor): Descale factor for S tensor. + descale_dP (cudnn_tensor): Descale factor for P gradient tensor. + scale_s (cudnn_tensor): Scale factor for S tensor. + scale_dQ (cudnn_tensor): Scale factor for query gradient. + scale_dK (cudnn_tensor): Scale factor for key gradient. + scale_dV (cudnn_tensor): Scale factor for value gradient. + scale_dP (cudnn_tensor): Scale factor for dP gradient. + attn_scale (Optional[Union[float, cudnn_tensor]]): The scale factor for attention. Default is None. + use_causal_mask (Optional[bool]): Whether to use causal mask. 
Default is False. + compute_data_type (Optional[cudnn.data_type]): The data type for computation. Default is NOT_SET. + name (Optional[str]): The name of the operation. + + Returns: + dQ (cudnn_tensor): The query gradient data. + dK (cudnn_tensor): The key gradient data. + dV (cudnn_tensor): The value gradient data. + amax_dQ (cudnn_tensor): The absolute maximum of query gradient tensor. + amax_dK (cudnn_tensor): The absolute maximum of key gradient tensor. + amax_dV (cudnn_tensor): The absolute maximum of value gradient tensor. + amax_dP (cudnn_tensor): The absolute maximum of dP tensor. )pbdoc"); m.attr("scaled_dot_product_flash_attention") = m.attr("sdpa"); m.attr("scaled_dot_product_flash_attention_backward") = m.attr("sdpa_backward"); diff --git a/requirements.txt b/requirements.txt index abe803ac..6e11de1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ jupyter numpy pybind11[global] pytest -pytest-xdist \ No newline at end of file +pytest-xdist +looseversion \ No newline at end of file diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 79eb226c..dfb316b4 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -31,6 +31,7 @@ add_executable( cpp/serialization.cpp cpp/autotuning.cpp cpp/pointwise.cpp + cpp/resample.cpp legacy_samples/conv_sample.cpp legacy_samples/resnet_test_list.cpp @@ -59,6 +60,7 @@ if (MSVC) /wd4458 # local hides class member (currently a problem for all inline setters) /wd4505 # unreferenced function with internal linkage has been removed /wd4101 /wd4189 # unreferenced local + /bigobj # increase number of sections in .Obj file ) else() target_compile_options( @@ -74,6 +76,7 @@ target_link_libraries( samples cudnn_frontend + _cudnn_frontend_pch Catch2::Catch2WithMain CUDNN::cudnn_all diff --git a/samples/cpp/matmuls.cpp b/samples/cpp/matmuls.cpp index dde5ab90..ed63c7d6 100644 --- a/samples/cpp/matmuls.cpp +++ b/samples/cpp/matmuls.cpp @@ -75,9 +75,11 @@ TEST_CASE("Matmul", "[matmul][graph]") { REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + graph.deselect_engines({"eng4_"}); + REQUIRE(graph.check_support(handle).is_good()); - REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::ALL).is_good()); // Run cudnn graph Surface C_gpu(b * m * n, false); diff --git a/samples/cpp/mha.cpp b/samples/cpp/mha.cpp index 5d676b95..fd8cc338 100644 --- a/samples/cpp/mha.cpp +++ b/samples/cpp/mha.cpp @@ -565,5 +565,493 @@ TEST_CASE("Flash backward", "[graph][mha][flash][backward]") { checkCudaErr(cudaDeviceSynchronize()); + cudnnDestroy(handle); +} + +TEST_CASE("sdpa_fp8_fprop", "[graph][mha][fp8][forward]") { + namespace fe = cudnn_frontend; + +#if CUDART_VERSION < 12000 + SKIP("Test requires cuda toolkit 12.0 or above"); + return; +#endif + + int64_t b = 2; // batch size + int64_t h = 2; // head dim + int64_t s = 512; // q,k,v tensor is padded to this seq length + int64_t d = 128; // hidden dim + + bool is_inference = false; + + fe::graph::Graph mha_graph; + mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q_dQ_O_dO_dims = std::vector({b, h, s, d}); + + auto QKV_strides = std::vector({s * 3 * h * d, d, 3 * h * d, 1}); // bs3hd + auto O_dO_strides = std::vector({s * h * d, d, h * d, 1}); // bhsd + + auto Q = + 
mha_graph.tensor(fe::graph::Tensor_attributes().set_name("Q").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); + auto K = + mha_graph.tensor(fe::graph::Tensor_attributes().set_name("K").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); + auto V = + mha_graph.tensor(fe::graph::Tensor_attributes().set_name("V").set_dim(Q_dQ_O_dO_dims).set_stride(QKV_strides)); + + float attn_scale = 0.123f; + + auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); + auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); + auto scale_o = mha_graph.tensor_like(descale_q, "Scale_O"); + + auto sdpa_fp8_options = fe::graph::SDPA_fp8_attributes() + .set_name("sdpa_fp8") + .set_is_inference(is_inference) + .set_causal_mask(true) + .set_attn_scale(attn_scale); + + auto [O, Stats, Amax_S, Amax_O] = + mha_graph.sdpa_fp8(Q, K, V, descale_q, descale_k, descale_v, descale_s, scale_s, scale_o, sdpa_fp8_options); + + O->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + Amax_O->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_S->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + // Check that Stats tensor is real, which is only when its training step + if (is_inference) { + REQUIRE(Stats == nullptr); + } else { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT); + } + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto status = mha_graph.validate(); + if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); + cudnnDestroy(handle); + return; + } + + REQUIRE(mha_graph.build_operation_graph(handle).is_good()); + auto plans = mha_graph.create_execution_plans({fe::HeurMode_t::A}); + REQUIRE(mha_graph.check_support(handle).is_good()); + REQUIRE(mha_graph.build_plans(handle).is_good()); + + //// Build variant pack + Surface qkvTensor(b * s * 3 * h * d, false); + Surface oTensor(b * s * h * d, false); + void* devPtrQ = qkvTensor.devPtr; + void* devPtrK = (qkvTensor.devPtr + h * d); + void* devPtrV = (qkvTensor.devPtr + 2 * h * d); + void* devPtrO = oTensor.devPtr; + + Surface descale_Q_Tensor(1, false); + Surface descale_K_Tensor(1, false); + Surface descale_V_Tensor(1, false); + Surface descale_S_Tensor(1, false); + Surface scale_S_Tensor(1, false); + Surface scale_O_Tensor(1, false); + Surface Amax_S_Tensor(1, false); + Surface Amax_O_Tensor(1, false); + + std::unordered_map, void*> variant_pack = { + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, devPtrO}, + {descale_q, descale_Q_Tensor.devPtr}, + {descale_k, descale_K_Tensor.devPtr}, + {descale_v, descale_V_Tensor.devPtr}, + {descale_s, descale_S_Tensor.devPtr}, + {scale_s, scale_S_Tensor.devPtr}, + {scale_o, scale_O_Tensor.devPtr}, + {Amax_S, Amax_S_Tensor.devPtr}, + {Amax_O, Amax_O_Tensor.devPtr}}; + + Surface stats_tensor(b * h * s * 1, false); + if (is_inference == false) { + variant_pack[Stats] = stats_tensor.devPtr; + } + + Surface workspace(mha_graph.get_workspace_size(), false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + 
checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} + +TEST_CASE("sdpa_fp8_bprop", "[graph][mha][fp8][backward]") { + namespace fe = cudnn_frontend; + +#if CUDART_VERSION < 12000 + SKIP("Test requires cuda toolkit 12.0 or above"); + return; +#endif + + int64_t b = 2; // batch size + int64_t h = 2; // head dim + int64_t s = 512; // q,k,v tensor is padded to this seq length + int64_t d = 128; // hidden dim + + // bs3hd + auto Q_dQ_O_dO_dims = std::vector({b, h, s, d}); + // QKV_strides + auto Q_dQ_strides = std::vector({s * 3 * h * d, d, 3 * h * d, 1}); // bs3hd + + auto Q_K_V_dQ_dK_dV_bulk_strides = std::vector({s * 3 * h * d, 3 * h * d, h * d, d, 1}); + + auto O_dO_strides = std::vector({s * h * d, d, h * d, 1}); // bshd + + auto K_V_dK_dV_dims{Q_dQ_O_dO_dims}; + auto K_V_dK_dV_strides{Q_dQ_strides}; + + auto MZ_OdO_dims = std::vector({b, h, s, 1}); + auto MZ_OdO_strides = std::vector({h * s, s, 1, 1}); + + fe::graph::Graph mha_graph; + mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto Q = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("Q").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); + auto K = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("K").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); + auto V = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("V").set_dim(K_V_dK_dV_dims).set_stride(K_V_dK_dV_strides)); + auto O = + mha_graph.tensor(fe::graph::Tensor_attributes().set_name("O").set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides)); + auto dO = mha_graph.tensor( + fe::graph::Tensor_attributes().set_name("dO").set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides)); + auto Stats = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_dim(MZ_OdO_dims) + .set_stride(MZ_OdO_strides) + .set_data_type(fe::DataType_t::FLOAT)); + + float attn_scale = 0.123f; + + auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); + auto descale_o = mha_graph.tensor_like(descale_q, "Descale_O"); + auto descale_dO = mha_graph.tensor_like(descale_q, "Descale_dO"); + auto descale_dP = mha_graph.tensor_like(descale_q, "Descale_dP"); + + auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); + auto scale_dP = mha_graph.tensor_like(descale_q, "Scale_dP"); + auto scale_dQ = mha_graph.tensor_like(descale_q, "Scale_dQ"); + auto scale_dK = mha_graph.tensor_like(descale_q, "Scale_dK"); + auto scale_dV = mha_graph.tensor_like(descale_q, "Scale_dV"); + + // options/attributes + auto sdpa_fp8_backwards_options = fe::graph::SDPA_fp8_backward_attributes() + .set_name("sdpa_fp8_backward") + .set_causal_mask(true) + .set_attn_scale(attn_scale); + + // output + auto [dQ, dK, dV, Amax_dQ, Amax_dK, Amax_dV, Amax_dP] = mha_graph.sdpa_fp8_backward(Q, + K, + V, + O, + dO, + Stats, + descale_q, + descale_k, + descale_v, + descale_o, + descale_dO, + descale_s, + descale_dP, + scale_s, + scale_dQ, + scale_dK, + scale_dV, + scale_dP, + sdpa_fp8_backwards_options); + + dQ->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + dK->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + 
dV->set_output(true).set_dim(Q_dQ_O_dO_dims).set_stride(O_dO_strides); + Amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + Amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto status = mha_graph.validate(); + if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { + REQUIRE(status.is_good()); + } else { + REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); + cudnnDestroy(handle); + return; + } + + REQUIRE(mha_graph.build_operation_graph(handle).is_good()); + REQUIRE(mha_graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(mha_graph.check_support(handle).is_good()); + REQUIRE(mha_graph.build_plans(handle).is_good()); + + // Surfaces + auto Q_K_V_dQ_dK_dV_bulk_dims{b * s * 3 * h * d}; + auto dO_O_dims{b * s * h * d}; + Surface qkvTensor{Q_K_V_dQ_dK_dV_bulk_dims, false}; + void* devPtrQ{qkvTensor.devPtr}; + void* devPtrK{qkvTensor.devPtr + h * d}; + void* devPtrV{qkvTensor.devPtr + 2 * h * d}; + + Surface dQdKdVTensor{Q_K_V_dQ_dK_dV_bulk_dims, false}; + void* devPtrdQ{dQdKdVTensor.devPtr}; + void* devPtrdK{dQdKdVTensor.devPtr + h * d}; + void* devPtrdV{dQdKdVTensor.devPtr + 2 * h * d}; + + Surface dOTensor{dO_O_dims, false}; + Surface OTensor{dO_O_dims, false}; + + Surface descale_Q_Tensor{1, false}; + Surface descale_K_Tensor{1, false}; + Surface descale_V_Tensor{1, false}; + Surface descale_S_Tensor{1, false}; + Surface descale_dP_Tensor{1, false}; + Surface descale_dO_Tensor{1, false}; + Surface descale_O_Tensor{1, false}; + + Surface scale_S_Tensor{1, false}; + Surface scale_dQ_Tensor{1, false}; + Surface scale_dK_Tensor{1, false}; + Surface scale_dV_Tensor{1, false}; + Surface scale_dP_Tensor{1, false}; + + Surface AMax_dQ_Tensor{1, false}; + Surface AMax_dK_Tensor{1, false}; + Surface AMax_dV_Tensor{1, false}; + Surface AMax_dP_Tensor{1, false}; + + Surface StatsTensor(b * h * s * 1, false); + + // Variant pack + std::unordered_map, void*> variant_pack{ + {Q, devPtrQ}, + {K, devPtrK}, + {V, devPtrV}, + {O, OTensor.devPtr}, + {dO, dOTensor.devPtr}, + {dQ, devPtrdQ}, + {dK, devPtrdK}, + {dV, devPtrdV}, + {descale_q, descale_Q_Tensor.devPtr}, + {descale_k, descale_K_Tensor.devPtr}, + {descale_v, descale_V_Tensor.devPtr}, + {descale_o, descale_O_Tensor.devPtr}, + {descale_dO, descale_dO_Tensor.devPtr}, + {descale_s, descale_S_Tensor.devPtr}, + {descale_dP, descale_dP_Tensor.devPtr}, + {scale_s, scale_S_Tensor.devPtr}, + {scale_dQ, scale_dQ_Tensor.devPtr}, + {scale_dK, scale_dK_Tensor.devPtr}, + {scale_dV, scale_dV_Tensor.devPtr}, + {scale_dP, scale_dP_Tensor.devPtr}, + {Stats, StatsTensor.devPtr}, + {Amax_dQ, AMax_dQ_Tensor.devPtr}, + {Amax_dK, AMax_dK_Tensor.devPtr}, + {Amax_dV, AMax_dV_Tensor.devPtr}, + {Amax_dP, AMax_dP_Tensor.devPtr}}; + + Surface workspace(mha_graph.get_workspace_size(), false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + + cudnnDestroy(handle); +} + +TEST_CASE("sdpa_fp8_gqa_bprop", "[graph][mha][fp8][backward]") { + namespace fe = cudnn_frontend; + +#if CUDART_VERSION < 12000 + SKIP("Test requires cuda toolkit 12.0 or above"); + return; +#endif + + int64_t b = 2; // batch size + int64_t h_qo = 12; // query/output 
head dim + int64_t h_kv = 4; // key/value head dim + int64_t s = 512; // q,k,v tensor is padded to this seq length + int64_t d = 128; // hidden dim + + // construct graph + std::vector qo_dim = {b, h_qo, s, d}; + std::vector kv_dim = {b, h_kv, s, d}; + std::vector qo_stride = {s * h_qo * d, d, h_qo * d, 1}; // bshd + std::vector kv_stride = {s * h_kv * d, d, h_kv * d, 1}; // bshd + + std::vector stats_dim = {b, h_qo, s, 1}; + std::vector stats_stride = {h_qo * s, s, 1, 1}; + + fe::graph::Graph mha_graph; + mha_graph.set_io_data_type(fe::DataType_t::FP8_E4M3) + .set_intermediate_data_type(fe::DataType_t::FLOAT) + .set_compute_data_type(fe::DataType_t::FLOAT); + + auto q = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("Q").set_dim(qo_dim).set_stride(qo_stride)); + auto k = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("K").set_dim(kv_dim).set_stride(kv_stride)); + auto v = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("V").set_dim(kv_dim).set_stride(kv_stride)); + auto o = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("O").set_dim(qo_dim).set_stride(qo_stride)); + auto dO = mha_graph.tensor(fe::graph::Tensor_attributes().set_name("dO").set_dim(qo_dim).set_stride(qo_stride)); + auto stats = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Stats") + .set_dim(stats_dim) + .set_stride(stats_stride) + .set_data_type(fe::DataType_t::FLOAT)); + + float attn_scale = 0.125f; + + auto descale_q = mha_graph.tensor(fe::graph::Tensor_attributes() + .set_name("Descale_Q") + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + auto descale_k = mha_graph.tensor_like(descale_q, "Descale_K"); + auto descale_v = mha_graph.tensor_like(descale_q, "Descale_V"); + auto descale_s = mha_graph.tensor_like(descale_q, "Descale_S"); + auto descale_o = mha_graph.tensor_like(descale_q, "Descale_O"); + auto descale_dO = mha_graph.tensor_like(descale_q, "Descale_dO"); + auto descale_dP = mha_graph.tensor_like(descale_q, "Descale_dP"); + + auto scale_s = mha_graph.tensor_like(descale_q, "Scale_S"); + auto scale_dP = mha_graph.tensor_like(descale_q, "Scale_dP"); + auto scale_dQ = mha_graph.tensor_like(descale_q, "Scale_dQ"); + auto scale_dK = mha_graph.tensor_like(descale_q, "Scale_dK"); + auto scale_dV = mha_graph.tensor_like(descale_q, "Scale_dV"); + + // clang-format off + auto [dQ, dK, dV, amax_dQ, amax_dK, amax_dV, amax_dP] = mha_graph.sdpa_fp8_backward( + q, k, v, o, dO, stats, + descale_q, descale_k, descale_v, descale_o, descale_dO, descale_s, descale_dP, + scale_s, scale_dQ, scale_dK, scale_dV, scale_dP, + fe::graph::SDPA_fp8_backward_attributes().set_name("sdpa_fp8_backward") + .set_causal_mask(true) + .set_attn_scale(attn_scale) + ); + // clang-format on + + dQ->set_output(true).set_dim(qo_dim).set_stride(qo_stride); + dK->set_output(true).set_dim(kv_dim).set_stride(kv_stride); + dV->set_output(true).set_dim(kv_dim).set_stride(kv_stride); + amax_dQ->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dK->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dV->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + amax_dP->set_output(true).set_dim({1, 1, 1, 1}).set_data_type(fe::DataType_t::FLOAT); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + auto status = mha_graph.validate(); + if ((cudnnGetVersion() >= 90100) && check_device_arch_newer_than("hopper")) { + REQUIRE(status.is_good()); + } else { + 
REQUIRE(status.get_code() == fe::error_code_t::GRAPH_NOT_SUPPORTED); + cudnnDestroy(handle); + return; + } + + REQUIRE(mha_graph.build_operation_graph(handle).is_good()); + REQUIRE(mha_graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(mha_graph.check_support(handle).is_good()); + REQUIRE(mha_graph.build_plans(handle).is_good()); + + // Surfaces that allocate GPU memory + Surface q_gpu(b * s * h_qo * d, false); + Surface k_gpu(b * s * h_kv * d, false); + Surface v_gpu(b * s * h_kv * d, false); + Surface o_gpu(b * s * h_qo * d, false); + + Surface stats_gpu(b * h_qo * s * 1, false); + + Surface dQ_gpu(b * s * h_qo * d, false); + Surface dK_gpu(b * s * h_kv * d, false); + Surface dV_gpu(b * s * h_kv * d, false); + Surface dO_gpu(b * s * h_qo * d, false); + + Surface descale_q_gpu(1, false); + Surface descale_k_gpu(1, false); + Surface descale_v_gpu(1, false); + Surface descale_o_gpu(1, false); + Surface descale_s_gpu(1, false); + Surface descale_dP_gpu(1, false); + Surface descale_dO_gpu(1, false); + + Surface scale_s_gpu(1, false); + Surface scale_dQ_gpu(1, false); + Surface scale_dK_gpu(1, false); + Surface scale_dV_gpu(1, false); + Surface scale_dP_gpu(1, false); + + Surface amax_dQ_gpu(1, false); + Surface amax_dK_gpu(1, false); + Surface amax_dV_gpu(1, false); + Surface amax_dP_gpu(1, false); + + // Variant pack + std::unordered_map, void*> variant_pack{ + {q, q_gpu.devPtr}, + {k, k_gpu.devPtr}, + {v, v_gpu.devPtr}, + {o, o_gpu.devPtr}, + + {dQ, dQ_gpu.devPtr}, + {dK, dK_gpu.devPtr}, + {dV, dV_gpu.devPtr}, + {dO, dO_gpu.devPtr}, + + {stats, stats_gpu.devPtr}, + + {descale_q, descale_q_gpu.devPtr}, + {descale_k, descale_k_gpu.devPtr}, + {descale_v, descale_v_gpu.devPtr}, + {descale_o, descale_o_gpu.devPtr}, + {descale_s, descale_s_gpu.devPtr}, + {descale_dP, descale_dP_gpu.devPtr}, + {descale_dO, descale_dO_gpu.devPtr}, + + {scale_s, scale_s_gpu.devPtr}, + {scale_dQ, scale_dQ_gpu.devPtr}, + {scale_dK, scale_dK_gpu.devPtr}, + {scale_dV, scale_dV_gpu.devPtr}, + {scale_dP, scale_dP_gpu.devPtr}, + + {amax_dQ, amax_dQ_gpu.devPtr}, + {amax_dK, amax_dK_gpu.devPtr}, + {amax_dV, amax_dV_gpu.devPtr}, + {amax_dP, amax_dP_gpu.devPtr}}; + + Surface workspace(mha_graph.get_workspace_size(), false); + REQUIRE(mha_graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudaErr(cudaDeviceSynchronize()); + cudnnDestroy(handle); } \ No newline at end of file diff --git a/samples/cpp/resample.cpp b/samples/cpp/resample.cpp new file mode 100644 index 00000000..2931d959 --- /dev/null +++ b/samples/cpp/resample.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ + +#include +#include "../utils/helpers.h" + +#include + +TEST_CASE("Resample Max Pooling NHWC Inference", "[resample][pooling][max][graph]") { + namespace fe = cudnn_frontend; + + // This example shows running max pooling graphs when in inference mode. + // See details about support surface in + // https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html#resamplefwd + + constexpr int N = 8; + constexpr int H = 56; + constexpr int W = 56; + constexpr int C = 8; + + fe::graph::Graph graph{}; + + graph.set_io_data_type(fe::DataType_t::HALF).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes().set_dim({N, C, H, W}).set_stride({H * W * C, 1, W * C, C})); + + auto [Y, Index] = graph.resample(X, + fe::graph::Resample_attributes() + .set_is_inference(true) + .set_resampling_mode(fe::ResampleMode_t::MAXPOOL) + .set_padding_mode(fe::PaddingMode_t::NEG_INF_PAD) + .set_window({2, 3}) + .set_stride({4, 5}) + .set_pre_padding({2, 3}) + .set_post_padding({4, 5})); + + Y->set_output(true); + assert(Index == nullptr); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.validate().is_good()); + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.check_support(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(N * H * W * C, false); + Surface Y_gpu(N * H * W * C, false); + std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, + {Y, Y_gpu.devPtr}}; + Surface workspace(graph.get_workspace_size(), false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudnnErr(cudnnDestroy(handle)); +} + +TEST_CASE("Resample Max Pooling NHWC Training", "[resample][pooling][max][graph]") { + namespace fe = cudnn_frontend; + + // This example shows running NHWC max pooling graphs. + // Support for NHWC max pooling has a fast path which can dump an index tensor from the forward pass. + // This means the backward pass can skip reading the full X tensor and instead just use this index tensor.
+ // See details about support surface and index tensor in + // https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html#resamplefwd + + constexpr int N = 8; + constexpr int H = 56; + constexpr int W = 56; + constexpr int C = 8; + + fe::graph::Graph graph{}; + + graph.set_io_data_type(fe::DataType_t::HALF).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes().set_dim({N, C, H, W}).set_stride({H * W * C, 1, W * C, C})); + + auto [Y, Index] = graph.resample(X, + fe::graph::Resample_attributes() + .set_is_inference(false) + .set_resampling_mode(fe::ResampleMode_t::MAXPOOL) + .set_padding_mode(fe::PaddingMode_t::NEG_INF_PAD) + .set_window({2, 3}) + .set_stride({4, 5}) + .set_pre_padding({2, 3}) + .set_post_padding({4, 5})); + + Y->set_output(true); + Index->set_output(true).set_data_type(fe::DataType_t::INT8); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.validate().is_good()); + + auto const status = graph.build_operation_graph(handle); + if (cudnn_frontend::get_backend_version() >= 8600) + REQUIRE(status.is_good()); + else { + REQUIRE(status.is_bad()); + SKIP("Using index tensor is not supported pre 8.6."); + } + REQUIRE(graph.check_support(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(N * H * W * C, false); + Surface Y_gpu(N * H * W * C, false); + Surface Index_gpu(N * H * W * C / 8, false); + std::unordered_map, void*> variant_pack = { + {X, X_gpu.devPtr}, {Y, Y_gpu.devPtr}, {Index, Index_gpu.devPtr}}; + Surface workspace(graph.get_workspace_size(), false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudnnErr(cudnnDestroy(handle)); +} + +TEST_CASE("Resample Avg Pooling", "[resample][pooling][average][graph]") { + namespace fe = cudnn_frontend; + + // This example shows running average pooling graphs. + // There is no difference between NHWC and NCHW support surface. 
+ // See details about support surface in + // https://docs.nvidia.com/deeplearning/cudnn/developer/graph-api.html#resamplefwd + + constexpr int N = 8; + constexpr int H = 56; + constexpr int W = 56; + constexpr int C = 8; + + fe::graph::Graph graph{}; + + graph.set_io_data_type(fe::DataType_t::HALF).set_compute_data_type(fe::DataType_t::FLOAT); + + auto X = graph.tensor(fe::graph::Tensor_attributes().set_dim({N, C, H, W}).set_stride({H * W * C, 1, W * C, C})); + + auto [Y, Index] = graph.resample(X, + fe::graph::Resample_attributes() + .set_is_inference(false) + .set_resampling_mode(fe::ResampleMode_t::AVGPOOL_INCLUDE_PADDING) + .set_padding_mode(fe::PaddingMode_t::ZERO_PAD) + .set_window({2, 3}) + .set_stride({4, 5}) + .set_pre_padding({2, 3}) + .set_post_padding({4, 5})); + + Y->set_output(true); + assert(Index == nullptr); + + cudnnHandle_t handle; + checkCudnnErr(cudnnCreate(&handle)); + + REQUIRE(graph.validate().is_good()); + REQUIRE(graph.build_operation_graph(handle).is_good()); + REQUIRE(graph.check_support(handle).is_good()); + REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); + REQUIRE(graph.build_plans(handle, fe::BuildPlanPolicy_t::HEURISTICS_CHOICE).is_good()); + + Surface X_gpu(N * H * W * C, false); + Surface Y_gpu(N * H * W * C, false); + std::unordered_map, void*> variant_pack = {{X, X_gpu.devPtr}, + {Y, Y_gpu.devPtr}}; + Surface workspace(graph.get_workspace_size(), false); + REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good()); + + checkCudnnErr(cudnnDestroy(handle)); +} \ No newline at end of file diff --git a/samples/utils/error_util.h b/samples/utils/error_util.h index c8abd199..6b9b0d9c 100644 --- a/samples/utils/error_util.h +++ b/samples/utils/error_util.h @@ -28,6 +28,7 @@ #include #include #include +#include #include diff --git a/test/python_fe/test_batchnorm.py b/test/python_fe/test_batchnorm.py index 03733326..b4e34301 100644 --- a/test/python_fe/test_batchnorm.py +++ b/test/python_fe/test_batchnorm.py @@ -1,6 +1,7 @@ import cudnn import pytest import torch +from looseversion import LooseVersion from test_utils import torch_fork_set_rng @@ -8,7 +9,7 @@ class SGBN(torch.nn.Module): def forward(self, input, running_mean, running_var, weight, bias, eps, momentum): return torch.nn.functional.batch_norm(input, running_mean, running_var, weight=weight, bias=bias, training=True, momentum=momentum, eps=eps) -@pytest.mark.skipif(cudnn.backend_version() < 8800, reason="BN with mask output not supported below cudnn 8.8") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.8", reason="BN with mask output not supported below cudnn 8.8") @torch_fork_set_rng(seed=0) def test_bn_relu_with_mask(): @@ -103,7 +104,7 @@ def test_bn_relu_with_mask(): torch.testing.assert_close(inv_var_expected, saved_inv_var_actual, atol=1e-3, rtol=1e-3) # torch.testing.assert_close(mask_expected, mask_actual) -@pytest.mark.skipif(cudnn.backend_version() < 8900, reason="DBN fusions not supported below cudnn 8.9") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.9", reason="DBN fusions not supported below cudnn 8.9") @torch_fork_set_rng(seed=0) def test_drelu_dadd_dbn(): @@ -176,7 +177,7 @@ def test_drelu_dadd_dbn(): device_buffers[DX_drelu] = DX_drelu_actual graph.execute(device_buffers, workspace) -@pytest.mark.skipif(cudnn.backend_version() < 8904, reason="BN_infer-Drelu-DBN not supported below cudnn 8.9.4") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.9.4", reason="BN_infer-Drelu-DBN not 
supported below cudnn 8.9.4") @torch_fork_set_rng(seed=0) def test_bn_infer_drelu_dbn(): diff --git a/test/python_fe/test_conv_bias.py b/test/python_fe/test_conv_bias.py index 400480e8..2ab6f730 100644 --- a/test/python_fe/test_conv_bias.py +++ b/test/python_fe/test_conv_bias.py @@ -1,6 +1,7 @@ import cudnn import pytest import torch +from looseversion import LooseVersion from test_utils import torch_fork_set_rng @@ -174,7 +175,7 @@ def dleaky_relu(grad: torch.Tensor, mask: torch.Tensor, negative_slope: float): torch.testing.assert_close(Y_expected, Y_actual, atol=1e-4, rtol=1e-4) -@pytest.mark.skipif(cudnn.backend_version() < 8600, reason="requires cudnn 8.6.0 or higher") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.6", reason="requires cudnn 8.6.0 or higher") @torch_fork_set_rng(seed=0) def test_conv_int8(): diff --git a/test/python_fe/test_conv_genstats.py b/test/python_fe/test_conv_genstats.py index c92c4f57..7bda2287 100644 --- a/test/python_fe/test_conv_genstats.py +++ b/test/python_fe/test_conv_genstats.py @@ -1,6 +1,7 @@ import cudnn import pytest import torch +from looseversion import LooseVersion from test_utils import torch_fork_set_rng @@ -21,7 +22,7 @@ def forward(self, scale, bias, x, w, padding = [1,1], stride = [1,1], dilation = stride = [1,1] dilation = [1,1] -@pytest.mark.skipif(cudnn.backend_version() < 8800, reason="requires cudnn 8.8 or higher") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.8", reason="requires cudnn 8.8 or higher") @torch_fork_set_rng(seed=0) def test_conv_genstats(): diff --git a/test/python_fe/test_instancenorm.py b/test/python_fe/test_instancenorm.py index 7029f54e..1ee4aace 100644 --- a/test/python_fe/test_instancenorm.py +++ b/test/python_fe/test_instancenorm.py @@ -2,6 +2,7 @@ import pytest import torch import itertools +from looseversion import LooseVersion from test_utils import torch_fork_set_rng @@ -13,7 +14,7 @@ def param_extract(request): return request.param -@pytest.mark.skipif(cudnn.backend_version() < 8905, reason="IN not supported below cudnn 8.9.5") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.9.5", reason="IN not supported below cudnn 8.9.5") @torch_fork_set_rng(seed=0) def test_in(param_extract): diff --git a/test/python_fe/test_layernorm.py b/test/python_fe/test_layernorm.py index 9adcd3b3..69956547 100644 --- a/test/python_fe/test_layernorm.py +++ b/test/python_fe/test_layernorm.py @@ -2,6 +2,7 @@ import pytest import torch import itertools +from looseversion import LooseVersion from test_utils import torch_fork_set_rng @@ -14,7 +15,7 @@ def param_extract(request): return request.param -@pytest.mark.skipif(cudnn.backend_version() < 8905, reason="LN not supported below cudnn 8.9.5") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.9.5", reason="LN not supported below cudnn 8.9.5") @torch_fork_set_rng(seed=0) def test_layernorm(param_extract): diff --git a/test/python_fe/test_matmul_bias_relu.py b/test/python_fe/test_matmul_bias_relu.py index c6bce03f..f5f916cd 100644 --- a/test/python_fe/test_matmul_bias_relu.py +++ b/test/python_fe/test_matmul_bias_relu.py @@ -2,6 +2,7 @@ import itertools import pytest import torch +from looseversion import LooseVersion from test_utils import torch_fork_set_rng @@ -25,7 +26,7 @@ def get_cc(): (major, minor) = torch.cuda.get_device_capability() return major*10 + minor -@pytest.mark.skipif(cudnn.backend_version() < 8906, reason="requires cudnn 8.9.6 or higher") 
+@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.9.6", reason="requires cudnn 8.9.6 or higher") @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch") @torch_fork_set_rng(seed=0) def test_int8_bf16_matmul(): @@ -69,7 +70,7 @@ def test_int8_bf16_matmul(): B_data_type_options = [torch.int8, torch.bfloat16, torch.float16] MMA_data_type_options = [torch.bfloat16, torch.float16, torch.float32] -@pytest.mark.skipif(cudnn.backend_version() < 8906, reason="requires cudnn 8.9.6 or higher") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.9.6", reason="requires cudnn 8.9.6 or higher") @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires Hopper or newer arch") @pytest.mark.parametrize("A_data_type", A_data_type_options) @pytest.mark.parametrize("B_data_type", B_data_type_options) @@ -106,10 +107,10 @@ def test_mixed_precision_matmul(A_data_type, B_data_type, MMA_data_type): A_casted.set_data_type(convert_to_cudnn_type(MMA_data_type)) # Casting input tensor B is only supported from cudnn v9 - if B_data_type != MMA_data_type and cudnn.backend_version() < 90000: + if B_data_type != MMA_data_type and LooseVersion(cudnn.backend_version_string()) < "9": pytest.skip("mixed precision on B only supported from cudnn v9.") - if cudnn.backend_version() < 90000: + if LooseVersion(cudnn.backend_version_string()) < "9": # Do not create a cast node B_casted = B else: @@ -152,7 +153,7 @@ def test_matmul_bias_relu(param_extract): problem_size_options, input_type = param_extract b, s, e = problem_size_options - if b > 1 and cudnn.backend_version() < 8906: + if b > 1 and LooseVersion(cudnn.backend_version_string()) < "8.9.6": pytest.skip("matmul broadcast only supported 8.9.6 onwards.") # Regression in cudnn backend where ampere does not support matmul broadcast starting 8.9.6 diff --git a/test/python_fe/test_mhas.py b/test/python_fe/test_mhas.py index b25a86f7..f115af3f 100644 --- a/test/python_fe/test_mhas.py +++ b/test/python_fe/test_mhas.py @@ -2,6 +2,7 @@ import pytest import torch import math +from looseversion import LooseVersion import random import os @@ -296,22 +297,24 @@ def test_sdpa(input_type, is_infer, arg_params): - if cudnn.backend_version() < 8903: + cudnn_version = LooseVersion(cudnn.backend_version_string()) + + if cudnn_version < "8.9.3": pytest.skip("SDPA fprop requires cudnn 8.9.3 or higher") - if head_group != "multi_head" and cudnn.backend_version() < 8907: + if head_group != "multi_head" and cudnn_version < "8.9.7": pytest.skip("GQA and MQA is only supported 8.9.7 onwards.") - if is_alibi and cudnn.backend_version() < 8904: + if is_alibi and cudnn_version < "8.9.4": pytest.skip("ALiBi mask is only supported 8.9.4 onwards.") - if is_padding and cudnn.backend_version() < 8903: + if is_padding and cudnn_version < "8.9.3": pytest.skip("Padding mask is only supported 8.9.3 onwards.") - if is_dropout and cudnn.backend_version() < 8906: + if is_dropout and cudnn_version < "8.9.6": pytest.skip("Dropout reference is only supported on 8.9.6 onwards.") - if is_ragged and cudnn.backend_version() < 90000: + if is_ragged and cudnn_version < "9": pytest.skip("Ragged tensor is only supported 9.0.0 onwards") if is_ragged and torch.cuda.get_device_capability()[0] < 9: @@ -358,26 +361,26 @@ def test_sdpa(input_type, h_k = int(arg_params.mha_h_k) if arg_params.mha_h_k != None else h_k h_v = int(arg_params.mha_h_v) if arg_params.mha_h_v != None else h_v - if d_qk != d_v and cudnn.backend_version() 
< 8906: + if d_qk != d_v and cudnn_version < "8.9.6": pytest.skip("d_qk != d_v is only supported on 8.9.6 onwards.") - if cudnn.backend_version() < 90000: + if cudnn_version < "9": if ((s_q % 64 != 0) or (s_kv % 64 != 0)) and (is_padding or is_dropout): pytest.skip("s_q not a multiple of 64 with padding/dropout is not supported with cudnn version 9.0.0") - if cudnn.backend_version() < 8906: + if cudnn_version < "8.9.6": pytest.skip("d not a multiple of 64, not-multiple-of-64 seq_kv is not supported below 8.9.6") - if (d_qk % 64 != 0) and cudnn.backend_version() < 8906: + if (d_qk % 64 != 0) and cudnn_version < "8.9.6": pytest.skip("d not a multiple of 64 is not supported below 8.9.6") - if (d_qk % 64 != 0) and cudnn.backend_version() < 8906: + if (d_qk % 64 != 0) and cudnn_version < "8.9.6": pytest.skip("d not a multiple of 64 is not supported below 8.9.6") - if d_qk != d_v and is_ragged: + if d_qk != d_v and is_ragged and cudnn_version < "9.1": pytest.skip("d_qk != d_v is not supported with ragged offset") - print(f"{b=} {s_q=} {s_kv=} {d_qk=} {d_v=} {h_q=} {h_k=} {h_v=}") + print(f"--mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v}") attn_scale = 0.125 dropout_prob = 0.1 if is_dropout else 0.0 @@ -585,13 +588,15 @@ def test_sdpa_backward(input_type, is_ragged, arg_params): - if cudnn.backend_version() < 8903: + cudnn_version = LooseVersion(cudnn.backend_version_string()) + + if cudnn_version < "8.9.3": pytest.skip("SDPA bprop requires cudnn 8.9.3 or higher") - if head_group != "multi_head" and cudnn.backend_version() < 8907: + if head_group != "multi_head" and cudnn_version < "8.9.7": pytest.skip("GQA and MQA is only supported 8.9.7 onwards.") - if is_bias and cudnn.backend_version() < 8906: + if is_bias and cudnn_version < "8.9.6": pytest.skip("dBias is only supported 8.9.6 onwards.") if is_bias and torch.cuda.get_device_capability()[0] < 9: @@ -603,16 +608,16 @@ def test_sdpa_backward(input_type, if is_alibi and not is_causal: pytest.skip("ALiBi mask is only supported with causal mask") - if is_alibi and cudnn.backend_version() < 8904: + if is_alibi and cudnn_version < "8.9.4": pytest.skip("ALiBi mask is only supported 8.9.4 onwards.") - if is_padding and cudnn.backend_version() < 8903: + if is_padding and cudnn_version < "8.9.3": pytest.skip("Padding mask is only supported 8.9.3 onwards.") - if is_dropout and cudnn.backend_version() < 8906: + if is_dropout and cudnn_version < "8.9.6": pytest.skip("RNG dump is only supported on 8.9.6 onwards.") - if is_ragged and cudnn.backend_version() < 90000: + if is_ragged and cudnn_version < "9": pytest.skip("Ragged tensor is only supported 9.0.0 onwards") if is_ragged and torch.cuda.get_device_capability()[0] < 9: @@ -655,10 +660,10 @@ def test_sdpa_backward(input_type, else: assert False, "Head group must be either MHA, GQA, or MQA" - if d_qk != d_v and cudnn.backend_version() < 8906: + if d_qk != d_v and cudnn_version < "8.9.6": pytest.skip("d_qk != d_v is only supported on 8.9.6 onwards.") - if (cudnn.backend_version() < 90000): + if (cudnn_version < "9"): if (s_q < 64): pytest.skip("s_q less than 64 is not supported before cudnn 9.0.0") @@ -668,13 +673,13 @@ def test_sdpa_backward(input_type, if ((s_q % 64 != 0) or (s_kv % 64 != 0)) and is_bias: pytest.skip("cudnn backend does not support bias with non-64-aligned seq_q or seq_kv.") - if (s_kv % 64 != 0) and cudnn.backend_version() < 8906: + if (s_kv % 64 != 0) and cudnn_version < "8.9.6": 
pytest.skip("not-multiple-of-64 seq_kv is not supported below 8.9.6") - if (d_qk % 64 != 0) and cudnn.backend_version() < 8906: + if (d_qk % 64 != 0) and cudnn_version < "8.9.6": pytest.skip("d not a multiple of 64 is not supported below 8.9.6") - if d_qk != d_v and is_ragged: + if d_qk != d_v and is_ragged and cudnn_version < "9.1": pytest.skip("d_qk != d_v is not supported with ragged offset") # -------------------------- override test parameters if args are provided ---------------- @@ -687,7 +692,7 @@ def test_sdpa_backward(input_type, h_k = int(arg_params.mha_h_k) if arg_params.mha_h_k != None else h_k h_v = int(arg_params.mha_h_v) if arg_params.mha_h_v != None else h_v - print(f"{b=} {s_q=} {s_kv=} {d_qk=} {d_v=} {h_q=} {h_k=} {h_v=}") + print(f"--mha_b={b} --mha_s_q={s_q} --mha_s_kv={s_kv} --mha_d_qk={d_qk} --mha_d_v={d_v} --mha_h_q={h_q} --mha_h_k={h_k} --mha_h_v={h_v}") attn_scale = 0.125 dropout_prob = 0.1 if is_dropout else 0.0 @@ -826,7 +831,7 @@ def test_sdpa_backward(input_type, graph.execute(variant_pack, workspace) torch.cuda.synchronize() - if cudnn.backend_version() < 8906 and is_padding: + if cudnn_version < "8.9.6" and is_padding: # zero out padded region of the output and stats for i, m in enumerate(seq_len_q_gpu): o_gpu[i, :, m:, :] = 0 @@ -1002,8 +1007,8 @@ def test_sdpa_backward(input_type, dBias_ref[i, :, :, n:] = 0 torch.testing.assert_close(dQ_ref, dQ_gpu, check_dtype=False, atol=2e-2, rtol=2e-2) - torch.testing.assert_close(dK_ref, dK_gpu, check_dtype=False, atol=2e-2 if input_type != torch.bfloat16 else 4e-2, rtol=2e-2) - torch.testing.assert_close(dV_ref, dV_gpu, check_dtype=False, atol=2e-2 if input_type != torch.bfloat16 else 4e-2, rtol=2e-2) + torch.testing.assert_close(dK_ref, dK_gpu, check_dtype=False, atol=2e-2 if input_type != torch.bfloat16 else 7e-2, rtol=2e-2) + torch.testing.assert_close(dV_ref, dV_gpu, check_dtype=False, atol=2e-2 if input_type != torch.bfloat16 else 7e-2, rtol=2e-2) if is_bias: torch.testing.assert_close(dBias_ref, dBias_gpu, check_dtype=False, atol=2e-2, rtol=2e-2) diff --git a/test/python_fe/test_rmsnorm.py b/test/python_fe/test_rmsnorm.py index 41cfad4c..ddc52811 100644 --- a/test/python_fe/test_rmsnorm.py +++ b/test/python_fe/test_rmsnorm.py @@ -2,6 +2,7 @@ import pytest import torch import itertools +from looseversion import LooseVersion import torch.nn as nn @@ -39,7 +40,7 @@ def forward(self, x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = No def param_extract(request): return request.param -@pytest.mark.skipif(cudnn.backend_version() < 8906, reason="RmsNorm not supported below cudnn 8.9.6") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.9.6", reason="RmsNorm not supported below cudnn 8.9.6") @torch_fork_set_rng(seed=0) def test_rmsnorm(param_extract): diff --git a/test/python_fe/test_wgrads.py b/test/python_fe/test_wgrads.py index f475b758..64a844d2 100644 --- a/test/python_fe/test_wgrads.py +++ b/test/python_fe/test_wgrads.py @@ -1,6 +1,7 @@ import cudnn import pytest import torch +from looseversion import LooseVersion from test_utils import torch_fork_set_rng @@ -21,47 +22,52 @@ def is_hopper_arch(): stride = [1,1] dilation = [1,1] -@pytest.mark.skipif(cudnn.backend_version() < 8800, reason="requires cudnn 8.8 or higher") +@pytest.mark.skipif(LooseVersion(cudnn.backend_version_string()) < "8.8", reason="requires cudnn 8.8 or higher") @torch_fork_set_rng(seed=0) def test_scale_bias_relu_wgrad(): - if not is_ampere_arch() and not is_hopper_arch(): - pytest.skip("SBR Wgrad is only 
supported on ampere and hopper.") + try: + if not is_ampere_arch() and not is_hopper_arch(): + pytest.skip("SBR Wgrad is only supported on ampere and hopper.") - # Reference - X_gpu = torch.randn(n, c, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) - DY_gpu = torch.randn(n, k, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) - scale = torch.randn(1, c, 1, 1, device = "cuda", dtype = torch.float16).to(memory_format=torch.channels_last) * 0.01 - bias = torch.randn(1, c, 1, 1, device = "cuda", dtype = torch.float16).to(memory_format=torch.channels_last) * 0.01 - DW_actual = torch.randn(k, c, 3, 3, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) - - graph = cudnn.pygraph(io_data_type = cudnn.data_type.HALF, intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) + # Reference + X_gpu = torch.randn(n, c, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + DY_gpu = torch.randn(n, k, 32, 32, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + scale = torch.randn(1, c, 1, 1, device = "cuda", dtype = torch.float16).to(memory_format=torch.channels_last) * 0.01 + bias = torch.randn(1, c, 1, 1, device = "cuda", dtype = torch.float16).to(memory_format=torch.channels_last) * 0.01 + DW_actual = torch.randn(k, c, 3, 3, requires_grad=False, device="cuda", dtype=torch.float16).to(memory_format=torch.channels_last) + + graph = cudnn.pygraph(io_data_type = cudnn.data_type.HALF, intermediate_data_type = cudnn.data_type.FLOAT, compute_data_type = cudnn.data_type.FLOAT) - X = graph.tensor(name = "X", dim = X_gpu.size(), stride = X_gpu.stride(), data_type = X_gpu.dtype) - DY = graph.tensor(name = "DY", dim = DY_gpu.size(), stride = DY_gpu.stride(), data_type = DY_gpu.dtype) - B = graph.tensor(name = "B", dim = bias.size(), stride = bias.stride(), data_type = bias.dtype) - S = graph.tensor(name = "S", dim = scale.size(), stride = scale.stride(), data_type = scale.dtype) + # X = graph.tensor(name = "X", dim = X_gpu.size(), stride = X_gpu.stride(), data_type = cudnn._compiled_module.data_type.DOUBLE) + X = graph.tensor(name = "X", dim = X_gpu.size(), stride = X_gpu.stride(), data_type = X_gpu.dtype) + DY = graph.tensor(name = "DY", dim = DY_gpu.size(), stride = DY_gpu.stride(), data_type = DY_gpu.dtype) + B = graph.tensor(name = "B", dim = bias.size(), stride = bias.stride(), data_type = bias.dtype) + S = graph.tensor(name = "S", dim = scale.size(), stride = scale.stride(), data_type = scale.dtype) - scale_output = graph.scale(name = "scale", input = X, scale = S) - bias_output = graph.bias(name = "bias", input = scale_output, bias = B) + scale_output = graph.scale(name = "scale", input = X, scale = S) + bias_output = graph.bias(name = "bias", input = scale_output, bias = B) - relu_output = graph.relu(name = "relu", input = bias_output) + relu_output = graph.relu(name = "relu", input = bias_output) - wgrad_output = graph.conv_wgrad(name = "wgrad", image = relu_output, loss = DY, padding = padding, stride = stride, dilation = dilation) - wgrad_output.set_output(True).set_dim([k, c, 3, 3]) - - graph.validate() - graph.build_operation_graph() - graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) - graph.check_support() - graph.build_plans() + wgrad_output = graph.conv_wgrad(name = "wgrad", image = relu_output, 
loss = DY, padding = padding, stride = stride, dilation = dilation) + wgrad_output.set_output(True).set_dim([k, c, 3, 3]) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK]) + graph.check_support() + graph.build_plans() - workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) + workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8) - DW_actual = torch.zeros_like(X_gpu) + DW_actual = torch.zeros_like(X_gpu) - print("Executing test_scale_bias_relu_wgrad") - graph.execute({X: X_gpu, DY: DY_gpu, B: bias, S: scale, wgrad_output: DW_actual}, workspace) + print("Executing test_scale_bias_relu_wgrad") + graph.execute({X: X_gpu, DY: DY_gpu, B: bias, S: scale, wgrad_output: DW_actual}, workspace) + + except cudnn.cudnnGraphNotSupportedError as ex: + print(ex) if __name__ == "__main__": test_scale_bias_relu_wgrad() \ No newline at end of file diff --git a/test/unit_tests/CMakeLists.txt b/test/unit_tests/CMakeLists.txt index bc9e9466..2c58f437 100644 --- a/test/unit_tests/CMakeLists.txt +++ b/test/unit_tests/CMakeLists.txt @@ -22,6 +22,7 @@ add_executable( serialize.cpp validate.cpp + version.cpp ) if (MSVC) @@ -32,7 +33,7 @@ if (MSVC) /wd4458 # local hides class member (currently a problem for all inline setters) /wd4505 # unreferenced function with internal linkage has been removed /wd4101 /wd4189 # unreferenced local - /bigobj + /bigobj # increase number of sections in .Obj file ) else() target_compile_options( @@ -47,6 +48,7 @@ endif() target_link_libraries( unit_tests cudnn_frontend + _cudnn_frontend_pch Catch2::Catch2WithMain CUDNN::cudnn_all diff --git a/test/unit_tests/version.cpp b/test/unit_tests/version.cpp new file mode 100644 index 00000000..9cbc830a --- /dev/null +++ b/test/unit_tests/version.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + * DEALINGS IN THE SOFTWARE. + */ +#include + +#include + +TEST_CASE("version checks", "[version]") { + namespace fe = cudnn_frontend; + + REQUIRE(fe::detail::convert_version_to_str(8907) == "8.9.7"); + REQUIRE(fe::detail::convert_version_to_str(90000) == "9.0.0"); + REQUIRE(fe::detail::convert_version_to_str(90100) == "9.1.0"); + REQUIRE(fe::detail::convert_version_to_str(123456) == "12.34.56"); +} \ No newline at end of file
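
For reference, a minimal Python sketch of how the new v1.3 bindings compose: version gating via backend_version_string() with looseversion, building a graph around the new sdpa_fp8 node, and catching the new cudnnGraphNotSupportedError. The tensor names, dims/strides, the 9.1/Hopper gate, and the cudnn.data_type.FP8_E4M3 spelling are illustrative assumptions taken from the C++ sample in samples/cpp/mha.cpp and the pytest suite; they are not additional changes in this patch.

import cudnn
from looseversion import LooseVersion

def build_sdpa_fp8_graph(is_inference=True):
    # The C++ sample gates fp8 attention on cudnn 9.1+ and Hopper-class GPUs;
    # the same version check is assumed here.
    if LooseVersion(cudnn.backend_version_string()) < "9.1":
        print("sdpa_fp8 requires cudnn 9.1 or higher")
        return None

    b, h, s, d = 2, 2, 512, 128  # illustrative batch, heads, sequence length, head size

    graph = cudnn.pygraph(io_data_type=cudnn.data_type.FP8_E4M3,
                          intermediate_data_type=cudnn.data_type.FLOAT,
                          compute_data_type=cudnn.data_type.FLOAT)

    bhsd_stride = [s * h * d, d, h * d, 1]
    q = graph.tensor(name="Q", dim=[b, h, s, d], stride=bhsd_stride)
    k = graph.tensor(name="K", dim=[b, h, s, d], stride=bhsd_stride)
    v = graph.tensor(name="V", dim=[b, h, s, d], stride=bhsd_stride)

    def scalar(name):
        # per-tensor scale/descale factors are single-element fp32 tensors
        return graph.tensor(name=name, dim=[1, 1, 1, 1], stride=[1, 1, 1, 1],
                            data_type=cudnn.data_type.FLOAT)

    descale_q, descale_k, descale_v = scalar("descale_q"), scalar("descale_k"), scalar("descale_v")
    descale_s, scale_s, scale_o = scalar("descale_s"), scalar("scale_s"), scalar("scale_o")

    # returns [o, stats, amax_s, amax_o], matching the docstring above
    o, stats, amax_s, amax_o = graph.sdpa_fp8(q=q, k=k, v=v,
                                              descale_q=descale_q, descale_k=descale_k,
                                              descale_v=descale_v, descale_s=descale_s,
                                              scale_s=scale_s, scale_o=scale_o,
                                              is_inference=is_inference,
                                              attn_scale=0.125, use_causal_mask=True)
    o.set_output(True)
    amax_s.set_output(True).set_data_type(cudnn.data_type.FLOAT)
    amax_o.set_output(True).set_data_type(cudnn.data_type.FLOAT)
    if not is_inference:
        stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    try:
        graph.validate()
        graph.build_operation_graph()
        graph.create_execution_plans([cudnn.heur_mode.A])
        graph.check_support()
        graph.build_plans()
    except cudnn.cudnnGraphNotSupportedError as ex:
        # the new exception type distinguishes unsupported graphs from programming errors
        print(ex)
        return None
    return graph

Execution then follows the usual pattern from the pytest suite: allocate device buffers (for example torch tensors), build a variant-pack dict keyed by the graph tensors, and call graph.execute(variant_pack, workspace).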