@@ -20,13 +20,15 @@ class KernelDefBuilder;
2020typedef std::map<size_t , OrtMemType> MemTypeMap;
2121
2222// note that input/output might be on CPU implicitly when the node is from CPU execution provider
23- inline bool MemTypeOnCpuExplicitly (const MemTypeMap& mem_type_map, size_t index) {
24- auto iter = mem_type_map.find (index);
25- return iter != mem_type_map.end () && (iter->second == OrtMemTypeCPUInput || iter->second == OrtMemTypeCPUOutput);
23+ inline bool MemTypeOnCpuExplicitly (OrtMemType mem_type) {
24+ return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
2625}
2726
2827class KernelDef {
2928 public:
29+ explicit KernelDef () : default_inputs_mem_type_(OrtMemTypeDefault), default_outputs_mem_type_(OrtMemTypeDefault) {
30+ }
31+
3032 const std::string& OpName () const {
3133 return op_name_;
3234 }
@@ -56,17 +58,20 @@ class KernelDef {
5658 return alias_map_;
5759 }
5860
59- const MemTypeMap& InputMemoryType () const {
60- return input_memory_type_args_;
61- }
62-
63- const MemTypeMap& OutputMemoryType () const {
64- return output_memory_type_args_ ;
61+ OrtMemType InputMemoryType (size_t input_index ) const {
62+ auto it = input_memory_type_args_. find (input_index) ;
63+ if (it == input_memory_type_args_. end ())
64+ return default_inputs_mem_type_;
65+ else
66+ return it-> second ;
6567 }
6668
67- // legacy interface for winml, should not be used in onnxruntime
68- const MemTypeMap& MemoryType () const {
69- return output_memory_type_args_;
69+ OrtMemType OutputMemoryType (size_t output_index) const {
70+ auto it = output_memory_type_args_.find (output_index);
71+ if (it == output_memory_type_args_.end ())
72+ return default_outputs_mem_type_;
73+ else
74+ return it->second ;
7075 }
7176
7277 int ExecQueueId () const {
@@ -111,6 +116,10 @@ class KernelDef {
111116
112117 // execution command queue id, 0 for default queue in execution provider
113118 int exec_queue_id_ = 0 ;
119+ // Default memory type for all inputs
120+ OrtMemType default_inputs_mem_type_;
121+ // Default memory type for all outputs
122+ OrtMemType default_outputs_mem_type_;
114123};
115124
116125class KernelDefBuilder {
@@ -212,6 +221,22 @@ class KernelDefBuilder {
212221 return *this ;
213222 }
214223
224+ /* *
225+ Specify the default inputs memory type, if not specified, it is DefaultMemory
226+ */
227+ KernelDefBuilder& SetDefaultInputsMemoryType (OrtMemType mem_type) {
228+ kernel_def_->default_inputs_mem_type_ = mem_type;
229+ return *this ;
230+ }
231+
232+ /* *
233+ Specify the default outputs memory type, if not specified, it is DefaultMemory
234+ */
235+ KernelDefBuilder& SetDefaultOutputMemoryType (OrtMemType mem_type) {
236+ kernel_def_->default_outputs_mem_type_ = mem_type;
237+ return *this ;
238+ }
239+
215240 /* *
216241 Return the kernel definition, passing ownership of the KernelDef to the caller
217242 */
0 commit comments