Skip to content

Commit

Permalink
[nnpack] Preallocate workspace buffer (apache#2369)
Browse files Browse the repository at this point in the history
  • Loading branch information
hlu1 authored and Wei Chen committed Feb 20, 2019
1 parent 844218d commit d0a9093
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions src/contrib/nnpack/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file Use external nnpack library call.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/util.h>
#include <dmlc/logging.h>
Expand Down Expand Up @@ -72,6 +73,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
}

size_t workspace_size = 0;
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_compute, input_channels,
output_channels, input_size, input_padding, kernel_size, stride_size,
nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);

// Division with rounding up, in case the size is not a multiple of sizeof(float)
const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);

TVMContext ctx = input->ctx;
TVMType type_hint = input->dtype;

DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
void* workspace_buffer =
cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
CHECK(workspace_buffer != nullptr);

for (auto n = 0; n < input->shape[0]; ++n) {
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_compute, input_channels,
Expand All @@ -85,10 +105,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
static_cast<float *>(output->data) + n * output->shape[1] *
output->shape[2] *
output->shape[3],
NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
workspace_buffer, &workspace_size,
nnp_activation_identity, nullptr, entry->threadpool, nullptr);

CHECK_EQ(status, nnp_status_success);
}
cpu_api->FreeWorkspace(ctx, workspace_buffer);
});

TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
Expand Down Expand Up @@ -147,6 +169,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
}

size_t workspace_size = 0;
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_reuse, input_channels,
output_channels, input_size, input_padding, kernel_size, stride_size,
nullptr, nullptr, nullptr, nullptr, nullptr, &workspace_size,
nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);

// Division with rounding up, in case the size is not a multiple of sizeof(float)
const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);

TVMContext ctx = input->ctx;
TVMType type_hint = input->dtype;

DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
void* workspace_buffer =
cpu_api->AllocWorkspace(ctx, workspace_elements * sizeof(float), type_hint);
CHECK(workspace_buffer != nullptr);

for (auto n = 0; n < input->shape[0]; ++n) {
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
Expand All @@ -159,10 +200,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
static_cast<float *>(output->data) + n * output->shape[1] *
output->shape[2] *
output->shape[3],
NULL, NULL,
nnp_activation_identity, NULL, entry->threadpool, NULL);
workspace_buffer, &workspace_size,
nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
}

cpu_api->FreeWorkspace(ctx, workspace_buffer);
});

TVM_REGISTER_GLOBAL(
Expand Down

0 comments on commit d0a9093

Please sign in to comment.