diff --git a/src/contrib/nnpack/convolution.cc b/src/contrib/nnpack/convolution.cc
index e600360c67f1b..887129819bc2e 100644
--- a/src/contrib/nnpack/convolution.cc
+++ b/src/contrib/nnpack/convolution.cc
@@ -2,6 +2,7 @@
  * Copyright (c) 2017 by Contributors
  * \file Use external nnpack library call.
  */
+#include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/util.h>
 #include <dmlc/logging.h>
@@ -72,6 +73,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
     zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
   }
+  size_t workspace_size = 0;
+  nnp_status status = nnp_convolution_inference(
+      algo, nnp_convolution_transform_strategy_compute, input_channels,
+      output_channels, input_size, input_padding, kernel_size, stride_size,
+      nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+      nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+  CHECK_EQ(status, nnp_status_success);
+
+  // Division with rounding up, in case size is not multiple of sizeof(float)
+  const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
+
+  TVMContext ctx = input->ctx;
+  TVMType type_hint = input->dtype;
+
+  DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+  void* workspace_buffer =
+      cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+  CHECK(workspace_buffer != nullptr);
+
   for (auto n = 0; n < input->shape[0]; ++n) {
     nnp_status status = nnp_convolution_inference(
         algo, nnp_convolution_transform_strategy_compute, input_channels,
@@ -85,10 +105,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
         static_cast<float*>(output->data) +
             n * output->shape[1] * output->shape[2] * output->shape[3],
-        NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
+        workspace_buffer, &workspace_size,
+        nnp_activation_identity, nullptr, entry->threadpool, nullptr);
     CHECK_EQ(status, nnp_status_success);
   }
+  cpu_api->FreeWorkspace(ctx, workspace_buffer);
 });

 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
@@ -147,6 +169,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
     zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
   }
+  size_t workspace_size = 0;
+  nnp_status status = nnp_convolution_inference(
+      algo, nnp_convolution_transform_strategy_reuse, input_channels,
+      output_channels, input_size, input_padding, kernel_size, stride_size,
+      nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
+      nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+  CHECK_EQ(status, nnp_status_success);
+
+  // Division with rounding up, in case size is not multiple of sizeof(float)
+  const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
+
+  TVMContext ctx = input->ctx;
+  TVMType type_hint = input->dtype;
+
+  DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
+  void* workspace_buffer =
+      cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
+  CHECK(workspace_buffer != nullptr);
+
   for (auto n = 0; n < input->shape[0]; ++n) {
     nnp_status status = nnp_convolution_inference(
         algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
@@ -159,10 +200,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
         static_cast<float*>(output->data) +
             n * output->shape[1] * output->shape[2] * output->shape[3],
-        NULL, NULL,
-        nnp_activation_identity, NULL, entry->threadpool, NULL);
+        workspace_buffer, &workspace_size,
+        nnp_activation_identity, nullptr, entry->threadpool, nullptr);
     CHECK_EQ(status, nnp_status_success);
   }
+
+  cpu_api->FreeWorkspace(ctx, workspace_buffer);
 });

 TVM_REGISTER_GLOBAL(
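
For reference, the pattern this patch applies is sketched below outside of diff context. It is a minimal, illustrative rewrite rather than the exact file contents: the wrapper function RunBatchedConvolution and its parameter list are invented for the example, while the NNPACK and TVM calls mirror the ones visible in the diff. The first nnp_convolution_inference call passes null data pointers and a null workspace buffer so that NNPACK only reports the required scratch size; that size is rounded up to whole floats, allocated once through DeviceAPI::AllocWorkspace, reused for every image in the batch, and released with FreeWorkspace after the loop.

    // Illustrative sketch of the query-then-allocate workspace pattern.
    // RunBatchedConvolution and its parameters are hypothetical names, not part of the patch.
    #include <nnpack.h>
    #include <tvm/runtime/device_api.h>
    #include <dmlc/logging.h>

    void RunBatchedConvolution(nnp_convolution_algorithm algo, size_t input_channels,
                               size_t output_channels, nnp_size input_size,
                               nnp_padding input_padding, nnp_size kernel_size,
                               nnp_size stride_size, DLTensor* input, DLTensor* kernel,
                               DLTensor* bias, DLTensor* output, pthreadpool_t threadpool) {
      using tvm::runtime::DeviceAPI;

      // 1. Query: with all data pointers and the workspace buffer set to null,
      //    NNPACK only writes the required workspace size.
      size_t workspace_size = 0;
      nnp_status status = nnp_convolution_inference(
          algo, nnp_convolution_transform_strategy_compute, input_channels,
          output_channels, input_size, input_padding, kernel_size, stride_size,
          nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
          nnp_activation_identity, nullptr, threadpool, nullptr);
      CHECK_EQ(status, nnp_status_success);

      // 2. Round the byte count up to whole floats so the allocation is never
      //    smaller than what NNPACK asked for.
      const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);

      // 3. Allocate once through the TVM device API.
      DeviceAPI* cpu_api = DeviceAPI::Get(input->ctx);
      void* workspace = cpu_api->AllocWorkspace(
          input->ctx, workspace_elements * sizeof(float), input->dtype);
      CHECK(workspace != nullptr);

      // 4. Reuse the same buffer for every image in the batch.
      for (int64_t n = 0; n < input->shape[0]; ++n) {
        status = nnp_convolution_inference(
            algo, nnp_convolution_transform_strategy_compute, input_channels,
            output_channels, input_size, input_padding, kernel_size, stride_size,
            static_cast<float*>(input->data) +
                n * input->shape[1] * input->shape[2] * input->shape[3],
            static_cast<float*>(kernel->data), static_cast<float*>(bias->data),
            static_cast<float*>(output->data) +
                n * output->shape[1] * output->shape[2] * output->shape[3],
            workspace, &workspace_size,
            nnp_activation_identity, nullptr, threadpool, nullptr);
        CHECK_EQ(status, nnp_status_success);
      }

      // 5. Return the buffer to the pool once the whole batch is done.
      cpu_api->FreeWorkspace(input->ctx, workspace);
    }

On CPU, AllocWorkspace is backed by TVM's workspace pool, so repeated invocations can reuse a cached allocation instead of hitting the system allocator on every call, which is the motivation for routing the NNPACK scratch buffer through it rather than allocating per batch element.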