Skip to content

Commit

Permalink
Rework C API to remove new/delete warnings (#14572)
Browse files Browse the repository at this point in the history
### Description
Re-work code so it does not require GSL_SUPPRESS

### Motivation and Context
Do things right.
  • Loading branch information
yuslepukhin authored Feb 8, 2023
1 parent 10ab252 commit 767619c
Showing 1 changed file with 54 additions and 46 deletions.
100 changes: 54 additions & 46 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -824,40 +824,45 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu
API_IMPL_BEGIN
auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);

std::vector<std::string> feed_names(input_len);
std::vector<OrtValue> feeds(input_len);
InlinedVector<std::string> feed_names;
feed_names.reserve(input_len);
InlinedVector<OrtValue> feeds;
feeds.reserve(input_len);

for (size_t i = 0; i != input_len; ++i) {
if (input_names[i] == nullptr || input_names[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "input name cannot be empty");
}

if (!input[i]) {
std::ostringstream ostr;
ostr << "NULL input supplied for input " << input_names[i];
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str());
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT,
MakeString("NULL input supplied for input ", input_names[i]).c_str());
}

feed_names[i] = input_names[i];
feeds[i] = *reinterpret_cast<const ::OrtValue*>(input[i]);
feed_names.emplace_back(input_names[i]);
feeds.emplace_back(*input[i]);
}

// Create output feed
std::vector<std::string> output_names(output_names_len);
InlinedVector<std::string> output_names;
output_names.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output_names1[i] == nullptr || output_names1[i][0] == '\0') {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "output name cannot be empty");
}
output_names[i] = output_names1[i];
output_names.emplace_back(output_names1[i]);
}

std::vector<OrtValue> fetches(output_names_len);
std::vector<OrtValue> fetches;
fetches.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] != nullptr) {
::OrtValue& value = *(output[i]);
fetches[i] = value;
fetches.emplace_back(*output[i]);
} else {
fetches.emplace_back();
}
}

Status status;
if (run_options == nullptr) {
OrtRunOptions op;
Expand All @@ -868,11 +873,24 @@ ORT_API_STATUS_IMPL(OrtApis::Run, _Inout_ OrtSession* sess, _In_opt_ const OrtRu

if (!status.IsOK())
return ToOrtStatus(status);

// We do it in two loops to make sure copy __ctors does not throw
InlinedVector<std::unique_ptr<OrtValue>> output_unique_ptrs;
output_unique_ptrs.reserve(output_names_len);
for (size_t i = 0; i != output_names_len; ++i) {
::OrtValue& value = fetches[i];
if (output[i] == nullptr) {
GSL_SUPPRESS(r .11)
output[i] = new OrtValue(value);
output_unique_ptrs.emplace_back(std::make_unique<OrtValue>(fetches[i]));
} else {
output_unique_ptrs.emplace_back();
}
}

assert(output_unique_ptrs.size() == output_names_len);

for (size_t i = 0; i != output_names_len; ++i) {
if (output[i] == nullptr) {
assert(output_unique_ptrs[i] != nullptr);
output[i] = output_unique_ptrs[i].release();
}
}
return nullptr;
Expand Down Expand Up @@ -912,8 +930,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateIoBinding, _Inout_ OrtSession* sess, _Outptr_
if (!status.IsOK()) {
return ToOrtStatus(status);
}
GSL_SUPPRESS(r .11)
*out = new OrtIoBinding(std::move(binding));
*out = std::make_unique<OrtIoBinding>(std::move(binding)).release();
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -1010,34 +1027,27 @@ ORT_API_STATUS_IMPL(OrtApis::GetBoundOutputValues, _In_ const OrtIoBinding* bind
}

// Used to destroy and de-allocate on exception
size_t created = 0;
IAllocatorUniquePtr<OrtValue*> ortvalues_alloc(reinterpret_cast<OrtValue**>(allocator->Alloc(allocator, outputs.size() * sizeof(OrtValue*))),
[&created, allocator](OrtValue** buffer) {
if (buffer) {
while (created > 0) {
auto p = buffer + --created;
delete (*p);
}
allocator->Free(allocator, buffer);
}
});

[allocator](OrtValue** p) { if (p) allocator->Free(allocator, p);});
if (!ortvalues_alloc) {
return OrtApis::CreateStatus(ORT_FAIL, "Output buffer allocation failed");
}

OrtValue** out_ptr = ortvalues_alloc.get();
InlinedVector<std::unique_ptr<OrtValue>> value_dups;
value_dups.reserve(outputs.size());

for (const auto& out_value : outputs) {
GSL_SUPPRESS(r .11)
*out_ptr = new OrtValue(out_value);
++out_ptr;
++created;
value_dups.push_back(std::make_unique<OrtValue>(out_value));
}

assert(created == outputs.size());

// The rest is noexcept
OrtValue** out_ptr = ortvalues_alloc.get();
for (auto& v : value_dups) {
*out_ptr++ = v.release();
}

*output = ortvalues_alloc.release();
*output_count = created;
*output_count = outputs.size();
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -1369,8 +1379,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetModelMetadata, _In_ const OrtSession* ses
auto p = session->GetModelMetadata();
if (!p.first.IsOK())
return ToOrtStatus(p.first);
GSL_SUPPRESS(r .11)
*out = reinterpret_cast<OrtModelMetadata*>(new ModelMetadata(*p.second));
*out = reinterpret_cast<OrtModelMetadata*>(std::make_unique<ModelMetadata>(*p.second).release());
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -2214,12 +2223,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionGetProfilingStartTimeNs, _In_ const OrtSessi
ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes,
int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out) {
API_IMPL_BEGIN
GSL_SUPPRESS(r .11)
*out = new OrtArenaCfg();
(*out)->max_mem = max_mem;
(*out)->arena_extend_strategy = arena_extend_strategy;
(*out)->initial_chunk_size_bytes = initial_chunk_size_bytes;
(*out)->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
auto cfg = std::make_unique<OrtArenaCfg>();
cfg->max_mem = max_mem;
cfg->arena_extend_strategy = arena_extend_strategy;
cfg->initial_chunk_size_bytes = initial_chunk_size_bytes;
cfg->max_dead_bytes_per_chunk = max_dead_bytes_per_chunk;
*out = cfg.release();
return nullptr;
API_IMPL_END
}
Expand Down Expand Up @@ -2254,9 +2263,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateArenaCfgV2, _In_reads_(num_keys) const char*
}

// Allow using raw new/delete because this is for C.
GSL_SUPPRESS(r .11)
ORT_API(void, OrtApis::ReleaseArenaCfg, _Frees_ptr_opt_ OrtArenaCfg* ptr) {
delete ptr;
std::unique_ptr<OrtArenaCfg> g(ptr);
}

ORT_API_STATUS_IMPL(OrtApis::CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out) {
Expand Down

0 comments on commit 767619c

Please sign in to comment.