Skip to content

Commit c453b48

Browse files
authored
update kernel memory type interface (#225)
* refactor the kernel memory type interface * remove useless change * fix comments in PR
1 parent a43382e commit c453b48

File tree

5 files changed

+59
-37
lines changed

5 files changed

+59
-37
lines changed

include/onnxruntime/core/framework/kernel_def_builder.h

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ class KernelDefBuilder;
2020
typedef std::map<size_t, OrtMemType> MemTypeMap;
2121

2222
// note that input/output might be on CPU implicitly when the node is from CPU execution provider
23-
inline bool MemTypeOnCpuExplicitly(const MemTypeMap& mem_type_map, size_t index) {
24-
auto iter = mem_type_map.find(index);
25-
return iter != mem_type_map.end() && (iter->second == OrtMemTypeCPUInput || iter->second == OrtMemTypeCPUOutput);
23+
inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
24+
return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
2625
}
2726

2827
class KernelDef {
2928
public:
29+
explicit KernelDef() : default_inputs_mem_type_(OrtMemTypeDefault), default_outputs_mem_type_(OrtMemTypeDefault) {
30+
}
31+
3032
const std::string& OpName() const {
3133
return op_name_;
3234
}
@@ -56,17 +58,20 @@ class KernelDef {
5658
return alias_map_;
5759
}
5860

59-
const MemTypeMap& InputMemoryType() const {
60-
return input_memory_type_args_;
61-
}
62-
63-
const MemTypeMap& OutputMemoryType() const {
64-
return output_memory_type_args_;
61+
OrtMemType InputMemoryType(size_t input_index) const {
62+
auto it = input_memory_type_args_.find(input_index);
63+
if (it == input_memory_type_args_.end())
64+
return default_inputs_mem_type_;
65+
else
66+
return it->second;
6567
}
6668

67-
// legacy interface for winml, should not be used in onnxruntime
68-
const MemTypeMap& MemoryType() const {
69-
return output_memory_type_args_;
69+
OrtMemType OutputMemoryType(size_t output_index) const {
70+
auto it = output_memory_type_args_.find(output_index);
71+
if (it == output_memory_type_args_.end())
72+
return default_outputs_mem_type_;
73+
else
74+
return it->second;
7075
}
7176

7277
int ExecQueueId() const {
@@ -111,6 +116,10 @@ class KernelDef {
111116

112117
// execution command queue id, 0 for default queue in execution provider
113118
int exec_queue_id_ = 0;
119+
// Default memory type for all inputs
120+
OrtMemType default_inputs_mem_type_;
121+
// Default memory type for all outputs
122+
OrtMemType default_outputs_mem_type_;
114123
};
115124

116125
class KernelDefBuilder {
@@ -212,6 +221,22 @@ class KernelDefBuilder {
212221
return *this;
213222
}
214223

224+
/**
225+
Specify the default inputs memory type, if not specified, it is DefaultMemory
226+
*/
227+
KernelDefBuilder& SetDefaultInputsMemoryType(OrtMemType mem_type) {
228+
kernel_def_->default_inputs_mem_type_ = mem_type;
229+
return *this;
230+
}
231+
232+
/**
233+
Specify the default outputs memory type, if not specified, it is DefaultMemory
234+
*/
235+
KernelDefBuilder& SetDefaultOutputMemoryType(OrtMemType mem_type) {
236+
kernel_def_->default_outputs_mem_type_ = mem_type;
237+
return *this;
238+
}
239+
215240
/**
216241
Return the kernel definition, passing ownership of the KernelDef to the caller
217242
*/

onnxruntime/core/framework/allocation_planner.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,6 @@ class PlannerImpl {
380380
ORT_ENFORCE(exec_provider);
381381

382382
auto& default_allocator_info = exec_provider->GetAllocator(0, OrtMemTypeDefault)->Info();
383-
auto& mem_type_allocated_args = p_kernelDef->OutputMemoryType();
384383
auto& outputs = pnode->OutputDefs();
385384
auto num_outputs = outputs.size();
386385

@@ -393,11 +392,11 @@ class PlannerImpl {
393392
if (strcmp(default_allocator_info.name, CPU) != 0) {
394393
// By default, outputs of this node are allocated on the default device allocator,
395394
// except for outputs marked for allocation in MemoryType:
396-
auto memory_type_iter = mem_type_allocated_args.find(i);
397-
if (memory_type_iter == mem_type_allocated_args.end()) {
395+
auto memory_type = p_kernelDef->OutputMemoryType(i);
396+
if (memory_type == OrtMemTypeDefault) {
398397
AllocPlan(index).location = default_allocator_info;
399398
} else {
400-
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type_iter->second)->Info();
399+
AllocPlan(index).location = exec_provider->GetAllocator(0, memory_type)->Info();
401400
}
402401
}
403402
}
@@ -438,7 +437,7 @@ class PlannerImpl {
438437

439438
thisplan.alloc_kind = AllocKind::kAllocateStatically;
440439
auto p_opkernelDef = utils::GetKernelDef(kernel_registry_, node);
441-
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(), index))
440+
if (MemTypeOnCpuExplicitly(p_opkernelDef->InputMemoryType(index)))
442441
// weights are not output from any node, so it's OK to put its location on CPU provider
443442
thisplan.location = execution_providers_.Get(onnxruntime::kCpuExecutionProvider)->GetAllocator(0, OrtMemTypeDefault)->Info();
444443
else

onnxruntime/core/framework/kernel_def_builder.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,20 +66,20 @@ bool KernelDef::IsConflict(const KernelDef& other) const {
6666
return false;
6767

6868
//check memory type
69-
auto other_input_mem_types = other.InputMemoryType();
69+
auto& other_input_mem_types = other.input_memory_type_args_;
7070
for (auto it : input_memory_type_args_) {
71-
if (other_input_mem_types.count(it.first) && other_input_mem_types[it.first] == it.second)
71+
if (other_input_mem_types.count(it.first) && other_input_mem_types.find(it.first)->second == it.second)
7272
return false;
7373
}
74-
if (input_memory_type_args_.empty() && !other.InputMemoryType().empty())
74+
if (input_memory_type_args_.empty() && !other.input_memory_type_args_.empty())
7575
return false;
7676

77-
auto other_output_mem_types = other.OutputMemoryType();
77+
auto& other_output_mem_types = other.output_memory_type_args_;
7878
for (auto it : output_memory_type_args_) {
79-
if (other_output_mem_types.count(it.first) && other_output_mem_types[it.first] == it.second)
79+
if (other_output_mem_types.count(it.first) && other_output_mem_types.find(it.second)->second == it.second)
8080
return false;
8181
}
82-
return !(output_memory_type_args_.empty() && !other.OutputMemoryType().empty());
82+
return !(output_memory_type_args_.empty() && !other.output_memory_type_args_.empty());
8383
}
8484

8585
KernelDefBuilder& KernelDefBuilder::SetName(const std::string& op_name) {

onnxruntime/core/framework/transformer_memcpy.cc

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,25 +68,24 @@ void TransformerMemcpyImpl::ProcessDefs(onnxruntime::Node& node, const KernelReg
6868
// note KernelCreateInfo might be nullptr for custom kernel
6969
const KernelCreateInfo* kci = nullptr;
7070
kernel_registries.SearchKernelRegistry(node, &kci);
71-
const auto* input_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
72-
const auto* output_mem_types = kci ? &kci->kernel_def->InputMemoryType() : nullptr;
71+
7372
ORT_ENFORCE(onnxruntime::Node::ForEachWithIndex(
74-
node.InputDefs(),
75-
[this, &input_mem_types](const onnxruntime::NodeArg& arg, size_t index) {
76-
if (input_mem_types && MemTypeOnCpuExplicitly(*input_mem_types, index))
77-
non_provider_input_defs_.insert(&arg);
78-
else
79-
provider_input_defs_.insert(&arg);
80-
return Status::OK();
81-
})
82-
.IsOK());
73+
node.InputDefs(),
74+
[this, &kci](const onnxruntime::NodeArg& arg, size_t index) {
75+
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index)))
76+
non_provider_input_defs_.insert(&arg);
77+
else
78+
provider_input_defs_.insert(&arg);
79+
return Status::OK();
80+
})
81+
.IsOK());
8382
auto& output_defs = node.MutableOutputDefs();
8483
for (size_t i = 0; i < output_defs.size(); ++i) {
8584
auto arg = output_defs[i];
8685
if (!arg->Exists())
8786
continue;
8887

89-
if (output_mem_types && MemTypeOnCpuExplicitly(*output_mem_types, i))
88+
if (kci && MemTypeOnCpuExplicitly(kci->kernel_def->OutputMemoryType(i)))
9089
non_provider_output_defs_.insert(arg);
9190
else
9291
provider_output_defs_.insert(arg);

onnxruntime/core/session/IOBinding.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,9 @@ common::Status IOBinding::CopyOneInputAcrossDevices(const SessionState& session_
6060
size_t index = node_info.index;
6161
auto& node = *node_info.p_node;
6262
const KernelCreateInfo* kci = node_info.kci;
63-
const auto* node_input_mem_types = (kci != nullptr) ? &kci->kernel_def->InputMemoryType() : nullptr;
6463

6564
// node may declare input_mem_type to be on CPU explicitly
66-
bool node_input_on_cpu = node_input_mem_types && MemTypeOnCpuExplicitly(*node_input_mem_types, index);
65+
bool node_input_on_cpu = kci && MemTypeOnCpuExplicitly(kci->kernel_def->InputMemoryType(index));
6766
auto& required_provider_type = node_input_on_cpu ? onnxruntime::kCpuExecutionProvider : node.GetExecutionProviderType();
6867
if (!orig_mlvalue.IsTensor()) {
6968
// copying not supported for non-tensor types

0 commit comments

Comments
 (0)