Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework C API to remove new/delete warnings #14572

Merged
merged 2 commits into from
Feb 8, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 54 additions & 46 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -824,40 +824,45 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

std::vector<std::string> feed_names(input_len);
std::vector<OrtValue> feeds(input_len);
InlinedVector<std::string> feed_names;
feed_names.reserve(input_len);
InlinedVector<OrtValue> feeds;
feeds.reserve(input_len);

for (size_t i = 0; i != input_len; ++i) {
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
}

if (!input[i]) {
std::ostringstream ostr;
ostr << "NULL input supplied for input " << input_names[i];
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str());
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
MakeString("NULL input supplied for input ", input_names[i]).c_str());
}

feed_names[i] = input_names[i];
feeds[i] = *reinterpret_cast<const ::OrtValue*>(input[i]);
feed_names.emplace_back(input_names[i]);
feeds.emplace_back(*input[i]);
}

// Create output feed
std::vector<std::string> output_names(output_names_len);
InlinedVector<std::string> output_names;
output_names.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
}
output_names[i] = output_names1[i];
output_names.emplace_back(output_names1[i]);
}

std::vector<OrtValue> fetches(output_names_len);
std::vector<OrtValue> fetches;
fetches.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] != nullptr) {
::OrtValue& value = *(output[i]);
fetches[i] = value;
fetches.emplace_back(*output[i]);
} else {
fetches.emplace_back();
}
}

Status status;
if (run_options == nullptr) {
OrtRunOptions op;
Expand All @@ -868,11 +873,24 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu

if (!status.IsOK())
return ToOrtStatus(status);

// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
output_unique_ptrs.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
::OrtValue& value = fetches[i];
if (output[i] == nullptr) {
GSL_SUPPRESS(r .11)
output[i] = new OrtValue(value);
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
} else {
output_unique_ptrs.emplace_back();
}
}

assert(output_unique_ptrs.size() == output_names_len);

for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
assert(output_unique_ptrs[i] != nullptr);
output[i] = output_unique_ptrs[i].release();
}
}
return nullptr;
Expand Down Expand Up @@ -912,8 +930,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateIoBinding, _Inout_ OrtSession* sess, _Outptr_
if (!status.IsOK()) {
return ToOrtStatus(status);
}
GSL_SUPPRESS(r .11)
*out = new OrtIoBinding(std::move(binding));
*out = std::make_unique<OrtIoBinding>(std::move(binding)).release();
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -1010,34 +1027,27 @@ ORT_API_STATUS_IMPL(OrtApis::GetBoundOutputValues, _In_ const OrtIoBinding* bind
}

// Used to destroy and de-allocate on exception
size_t created = 0;
IAllocatorUniquePtr<OrtValue*> ortvalues_alloc(reinterpret_cast<OrtValue**>(allocator->Alloc(allocator, outputs.size() * sizeof(OrtValue*))),
[&created, allocator](OrtValue** buffer) {
if (buffer) {
while (created > 0) {
auto p = buffer + --created;
delete (*p);
}
allocator->Free(allocator, buffer);
}
});

[allocator](OrtValue** p) { if (p) allocator->Free(allocator, p);});
if (!ortvalues_alloc) {
return OrtApis::CreateStatus(ORT_FAIL, "Output buffer allocation failed");
}

OrtValue** out_ptr = ortvalues_alloc.get();
InlinedVector<std::unique_ptr<OrtValue>> value_dups;
value_dups.reserve(outputs.size());

for (const auto& out_value : outputs) {
GSL_SUPPRESS(r .11)
*out_ptr = new OrtValue(out_value);
++out_ptr;
++created;
value_dups.push_back(std::make_unique<OrtValue>(out_value));
}

assert(created == outputs.size());

// The rest is noexcept
OrtValue** out_ptr = ortvalues_alloc.get();
for (auto& v : value_dups) {
*out_ptr++ = v.release();
}

*output = ortvalues_alloc.release();
*output_count = created;
*output_count = outputs.size();
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -1369,8 +1379,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetModelMetadata, _In_ const OrtSession* ses
auto p = session->GetModelMetadata();
if (!p.first.IsOK())
return ToOrtStatus(p.first);
GSL_SUPPRESS(r .11)
*out = reinterpret_cast<OrtModelMetadata*>(new ModelMetadata(*p.second));
*out = reinterpret_cast<OrtModelMetadata*>(std::make_unique<ModelMetadata>(*p.second).release());
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -2214,12 +2223,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetProfilingStartTimeNs, _In_ const OrtSessi
ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes,
int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out) {
API_IMPL_BEGIN
GSL_SUPPRESS(r .11)
*out = new OrtArenaCfg();
(*out)->max_mem = max_mem;
(*out)->arena_extend_strategy = arena_extend_strategy;
(*out)->initial_chunk_size_bytes = initial_chunk_size_bytes;
(*out)->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
auto cfg = std::make_unique<OrtArenaCfg>();
cfg->max_mem = max_mem;
cfg->arena_extend_strategy = arena_extend_strategy;
cfg->initial_chunk_size_bytes = initial_chunk_size_bytes;
cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
*out = cfg.release();
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -2254,9 +2263,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfgV2, _In_reads_(num_keys) const char*
}

// Allow using raw new/delete because this is for C.
GSL_SUPPRESS(r .11)
ORT_API(void, OrtApis::ReleaseArenaCfg, _Frees_ptr_opt_ OrtArenaCfg* ptr) {
delete ptr;
std::unique_ptr<OrtArenaCfg> g(ptr);
}

ORT_API_STATUS_IMPL(OrtApis::CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out) {
Expand Down