Crash in "SGBN Add Relu Graph" Test Case When All Tensor Dimensions Are the Same #122

Open
bokutotu opened this issue Dec 14, 2024 · 3 comments

bokutotu commented Dec 14, 2024

Describe the bug

When all dimensions of the tensor in the "SGBN Add Relu Graph" test case are set to the same number, the program crashes. The failure occurs during the batch normalization step, which should be a valid scenario.

Expected behavior
The program should handle tensors with all dimensions being the same number without crashing, and the batch normalization process should complete successfully.

System Environment (please complete the following information):

  • cudnn_frontend version: [v1.8.0]
  • cudnn_backend version: [v9.5.1]
  • GPU arch: [RTX 4090]
  • cuda runtime version: [12.3]
  • cuda driver version: [565.72]
  • host compiler: [e.g. clang14]
  • OS: [e.g. ubuntu22.04 (wsl2)]

API logs
Please attach API logs for both cudnn_frontend and cudnn_backend.
frontend

To Reproduce
Steps to reproduce the behavior:

  1. Set up the "SGBN Add Relu Graph" test case with a tensor where all dimensions are the same number.
  2. Run the test case.
  3. Observe the crash during the batch normalization process.

TEST_CASE("SGBN Add Relu Graph", "[batchnorm][graph]") {
    namespace fe = cudnn_frontend;
    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::HALF)
        .set_intermediate_data_type(fe::DataType_t::FLOAT)
        .set_compute_data_type(fe::DataType_t::FLOAT);

    bool has_running_stats = true;
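    // Note: every tensor dimension below is set to 2 (N = C = H = W = 2);
    // this all-equal configuration is what triggers the reported failure.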
    auto X                 = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("X")
                              .set_dim({2, 2, 2, 2})
                              .set_stride({2 * 2 * 2, 2 * 2, 2, 1}));
    auto prev_running_mean = graph.tensor(fe::graph::Tensor_attributes()
                                              .set_name("prev_running_mean")
                                              .set_dim({1, 2, 1, 1})
                                              .set_stride({2, 1, 2, 2})
                                              .set_data_type(fe::DataType_t::FLOAT));
    auto prev_running_var  = graph.tensor(fe::graph::Tensor_attributes()
                                             .set_name("prev_running_var")
                                             .set_dim({1, 2, 1, 1})
                                             .set_stride({2, 1, 2, 2})
                                             .set_data_type(fe::DataType_t::FLOAT));
    auto scale             = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("scale")
                                  .set_dim({1, 2, 1, 1})
                                  .set_stride({2, 1, 2, 2})
                                  .set_data_type(fe::DataType_t::FLOAT));
    auto bias              = graph.tensor(fe::graph::Tensor_attributes()
                                 .set_name("bias")
                                 .set_dim({1, 2, 1, 1})
                                 .set_stride({2, 1, 2, 2})
                                 .set_data_type(fe::DataType_t::FLOAT));

    auto peer_stats_0 = graph.tensor(fe::graph::Tensor_attributes()
                                         .set_dim({2, 4 * 2, 1, 1})
                                         .set_stride({4 * 2, 1, 4 * 2, 4 * 2})
                                         .set_data_type(fe::DataType_t::FLOAT));
    auto peer_stats_1 = graph.tensor(fe::graph::Tensor_attributes()
                                         .set_dim({2, 4 * 2, 1, 1})
                                         .set_stride({4 * 2, 1, 4 * 2, 4 * 2})
                                         .set_data_type(fe::DataType_t::FLOAT));

    auto epsilon  = graph.tensor(1e-05f);
    auto momentum = graph.tensor(1e-01f);

    auto batchnorm_options =
        fe::graph::Batchnorm_attributes().set_epsilon(epsilon).set_peer_stats({peer_stats_0, peer_stats_1});
    if (has_running_stats) {
        batchnorm_options.set_previous_running_stats(prev_running_mean, prev_running_var, momentum);
    }

    auto [bn_output, mean, inv_variance, next_running_mean, next_running_var] =
        graph.batchnorm(X, scale, bias, batchnorm_options);
    mean->set_output(true).set_data_type(fe::DataType_t::FLOAT);
    inv_variance->set_output(true).set_data_type(fe::DataType_t::FLOAT);

    if (has_running_stats) {
        next_running_mean->set_output(true).set_data_type(fe::DataType_t::FLOAT);
    }
    if (has_running_stats) {
        next_running_var->set_output(true).set_data_type(fe::DataType_t::FLOAT);
    }

    auto A           = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("A")
                              .set_dim({2, 2, 2, 2})
                              .set_stride({2 * 2 * 2, 2 * 2, 2, 1})
                              .set_data_type(fe::DataType_t::HALF));
    auto add_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD);
    auto add_output  = graph.pointwise(bn_output, A, add_options);

    auto relu_options = fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::RELU_FWD);
    auto Y            = graph.pointwise(add_output, relu_options);
    Y->set_output(true);

#if (CUDNN_VERSION < 8700)
    SKIP("single GPU BN is not supported in cudnn versions prior to 8.7");
#endif
    if (check_device_arch_newer_than("ampere") == false) {
        SKIP("ConvBNFprop requires Ampere and up");
    }
    cudnnHandle_t handle;
    CUDNN_CHECK(cudnnCreate(&handle));

    REQUIRE(graph.validate().is_good());

    REQUIRE(graph.build_operation_graph(handle).is_good());

    REQUIRE(graph.create_execution_plans({fe::HeurMode_t::FALLBACK}).is_good());

    REQUIRE(graph.check_support(handle).is_good());

    REQUIRE(graph.build_plans(handle).is_good());
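
    // Note: the Surface sizes below are carried over from the original sample
    // (4 x 32 x 16 x 16, 32 channels, etc.); they are larger than the
    // 2 x 2 x 2 x 2 tensors declared above strictly need.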

    Surface<half> X_tensor(4 * 32 * 16 * 16, false);
    Surface<float> Mean_tensor(32, false);
    Surface<float> Var_tensor(32, false);
    Surface<float> Previous_running_mean_tensor(32, false);
    Surface<float> Previous_running_var_tensor(32, false);
    Surface<float> Next_running_mean_tensor(32, false);
    Surface<float> Next_running_var_tensor(32, false);
    Surface<float> Scale_tensor(32, false);
    Surface<float> Bias_tensor(32, false);
    Surface<half> A_tensor(4 * 32 * 16 * 16, false);
    Surface<half> Y_tensor(4 * 32 * 16 * 16, false);
    Surface<float> Peer_stats_0_tensor(2 * 4 * 32, false, true);
    Surface<float> Peer_stats_1_tensor(2 * 4 * 32, false);

    int64_t workspace_size;
    REQUIRE(graph.get_workspace_size(workspace_size).is_good());
    Surface<int8_t> workspace(workspace_size, false);

    std::unordered_map<std::shared_ptr<fe::graph::Tensor_attributes>, void*> variant_pack = {
        {X, X_tensor.devPtr},
        {mean, Mean_tensor.devPtr},
        {inv_variance, Var_tensor.devPtr},
        {scale, Scale_tensor.devPtr},
        {bias, Bias_tensor.devPtr},
        {A, A_tensor.devPtr},
        {Y, Y_tensor.devPtr},
        {peer_stats_0, Peer_stats_0_tensor.devPtr},
        {peer_stats_1, Peer_stats_1_tensor.devPtr}};

    if (has_running_stats) {
        variant_pack[prev_running_mean] = Previous_running_mean_tensor.devPtr;
        variant_pack[prev_running_var]  = Previous_running_var_tensor.devPtr;
        variant_pack[next_running_mean] = Next_running_mean_tensor.devPtr;
        variant_pack[next_running_var]  = Next_running_var_tensor.devPtr;
    }
    REQUIRE(graph.execute(handle, variant_pack, workspace.devPtr).is_good());

    cudnnDestroy(handle);
}
-------------------------------------------------------------------------------
SGBN Add Relu Graph
-------------------------------------------------------------------------------
/home/bokutotu/cudnn-frontend/samples/cpp/norm/batchnorm.cpp:121
...............................................................................

/home/bokutotu/cudnn-frontend/samples/cpp/norm/batchnorm.cpp:209: FAILED:
  REQUIRE( graph.create_execution_plans({fe::HeurMode_t::FALLBACK}).is_good() )
with expansion:
  false

===============================================================================
test cases: 1 | 1 failed
assertions: 3 | 2 passed | 1 failed

Additional context
This issue seems to be related to how cudnn-frontend handles tensors with identical dimensions. It might be a bug in the cudnn-frontend library.

bokutotu commented Dec 14, 2024

I searched the code for CUDNN_FRONTEND_LOG_FLIE with git grep, and the string does not appear anywhere. What is setting this environment variable supposed to do?

https://github.com/NVIDIA/cudnn-frontend/blob/main/.github/ISSUE_TEMPLATE/bug_report.md?plain=1#L29

Where can I find documentation that matches what the code actually does?

Anerudhan commented Dec 16, 2024

Hi @bokutotu,

Thanks for the question. In this sample, the channel count should be a multiple of 8 for half-precision input. Can you please check whether a dim like 16,16,16,16 works?

Please see documentation here

For FP16 and BF16 data types, the channel count C for the tensors must be a multiple of 8 while for float data type the channel count must be a multiple of 4.
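
As a rough sketch of that suggestion (the exact values are assumptions, not taken from this thread), the repro above could be adjusted so the channel count C is a multiple of 8, e.g. NCHW = 16x16x16x16; the per-channel tensors then become {1, 16, 1, 1}:

    // Hypothetical adjustment of the repro: C = 16 satisfies the
    // multiple-of-8 requirement for HALF I/O tensors.
    auto X = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("X")
                              .set_dim({16, 16, 16, 16})
                              .set_stride({16 * 16 * 16, 16 * 16, 16, 1}));
    auto scale = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("scale")
                                  .set_dim({1, 16, 1, 1})
                                  .set_stride({16, 1, 16, 16})
                                  .set_data_type(fe::DataType_t::FLOAT));
    // bias, prev_running_mean and prev_running_var would follow the same
    // {1, 16, 1, 1} / {16, 1, 16, 16} layout, the peer stats would become
    // {2, 4 * 16, 1, 1}, and the Surface allocations (X_tensor, A_tensor,
    // Y_tensor, per-channel buffers) would need to match the new element counts.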

Regarding your question:

What is setting this environment variable supposed to do?

This environment variable sets the file that the frontend log is written to. I noticed there is a typo: CUDNN_FRONTEND_LOG_FLIE should be CUDNN_FRONTEND_LOG_FILE. Thanks for pointing it out.

-Anerudhan
