2626 * from 4. to 6. is the Inference Loop
2727 */
2828
29+ /* these limits are arbitrary. */
30+ #define MAX_GRAPHS 4
31+ #define MAX_EXECUTION_CONTEXTS 4
32+
2933typedef struct {
3034 ov_core_t * core ;
3135 /* keep input model files */
32- void * weight_data ;
33- ov_tensor_t * weights_tensor ;
34- ov_model_t * model ;
35- ov_compiled_model_t * compiled_model ;
36- ov_infer_request_t * infer_request ;
37- ov_tensor_t * input_tensor ;
36+ struct OpenVINOGraph {
37+ void * weight_data ;
38+ ov_tensor_t * weights_tensor ;
39+ ov_model_t * model ;
40+ ov_compiled_model_t * compiled_model ;
41+ } graphs [MAX_GRAPHS ];
42+ struct OpenVINOExecutionContext {
43+ struct OpenVINOGraph * graph ;
44+ ov_infer_request_t * infer_request ;
45+ } execution_contexts [MAX_EXECUTION_CONTEXTS ];
46+ unsigned int n_graphs ;
47+ unsigned int n_execution_contexts ;
3848} OpenVINOContext ;
3949
4050/*
@@ -179,6 +189,29 @@ wasi_nn_tensor_type_to_openvino_element_type(tensor_type wasi_nn_type)
179189 return UNDEFINED ;
180190}
181191
192+ static void
193+ free_graph (struct OpenVINOGraph * graph )
194+ {
195+ if (graph -> weight_data )
196+ os_free (graph -> weight_data );
197+
198+ if (graph -> weights_tensor )
199+ ov_tensor_free (graph -> weights_tensor );
200+
201+ if (graph -> model )
202+ ov_model_free (graph -> model );
203+
204+ if (graph -> compiled_model )
205+ ov_compiled_model_free (graph -> compiled_model );
206+ }
207+
208+ static void
209+ free_execution_context (struct OpenVINOExecutionContext * c )
210+ {
211+ if (c -> infer_request )
212+ ov_infer_request_free (c -> infer_request );
213+ }
214+
182215static wasi_nn_error
183216uint32_array_to_int64_array (uint32_t array_size , uint32_t * src , int64_t * * dst )
184217{
@@ -198,6 +231,8 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding,
198231 execution_target target , graph * g )
199232{
200233 OpenVINOContext * ov_ctx = (OpenVINOContext * )ctx ;
234+ struct OpenVINOGraph * graph ;
235+ unsigned int graph_idx ;
201236 wasi_nn_error ret = unsupported_operation ;
202237
203238 if (encoding != openvino ) {
@@ -223,65 +258,127 @@ load(void *ctx, graph_builder_array *builder, graph_encoding encoding,
223258 graph_builder xml = builder -> buf [0 ];
224259 graph_builder weight = builder -> buf [1 ];
225260
261+ graph_idx = ov_ctx -> n_graphs ;
262+ if (graph_idx >= MAX_GRAPHS ) {
263+ return runtime_error ;
264+ }
265+ graph = & ov_ctx -> graphs [graph_idx ];
266+ memset (graph , 0 , sizeof (* graph ));
267+
226268 /* transfer weight to an ov tensor */
227269 {
228- ov_ctx -> weight_data = os_malloc (weight .size );
229- if (!ov_ctx -> weight_data )
270+ graph -> weight_data = os_malloc (weight .size );
271+ if (!graph -> weight_data )
230272 goto fail ;
231- memcpy (ov_ctx -> weight_data , weight .buf , weight .size );
273+ memcpy (graph -> weight_data , weight .buf , weight .size );
232274
233275 ov_element_type_e type = U8 ;
234276 int64_t dims [1 ] = { weight .size };
235277 ov_shape_t shape = { 1 , dims };
236278 CHECK_OV_STATUS (ov_tensor_create_from_host_ptr (type , shape ,
237- ov_ctx -> weight_data ,
238- & ov_ctx -> weights_tensor ),
279+ graph -> weight_data ,
280+ & graph -> weights_tensor ),
239281 ret );
240282 }
241283
242284 /* load model from buffer */
243285 CHECK_OV_STATUS (ov_core_read_model_from_memory_buffer (
244286 ov_ctx -> core , (char * )xml .buf , xml .size ,
245- ov_ctx -> weights_tensor , & ov_ctx -> model ),
287+ graph -> weights_tensor , & graph -> model ),
246288 ret );
247289#ifndef NDEBUG
248290 print_model_input_output_info (ov_ctx -> model );
249291#endif
250292
251- ret = success ;
293+ CHECK_OV_STATUS (ov_core_compile_model (ov_ctx -> core , graph -> model , "CPU" , 0 ,
294+ & graph -> compiled_model ),
295+ ret );
296+
297+ * g = graph_idx ;
298+ ov_ctx -> n_graphs ++ ;
299+ return success ;
252300fail :
301+ free_graph (graph );
253302 return ret ;
254303}
255304
256305__attribute__((visibility ("default" ))) wasi_nn_error
257306load_by_name (void * ctx , const char * filename , uint32_t filename_len , graph * g )
258307{
259308 OpenVINOContext * ov_ctx = (OpenVINOContext * )ctx ;
309+ struct OpenVINOGraph * graph ;
310+ unsigned int graph_idx ;
260311 wasi_nn_error ret = unsupported_operation ;
261312
313+ graph_idx = ov_ctx -> n_graphs ;
314+ if (graph_idx >= MAX_GRAPHS ) {
315+ return runtime_error ;
316+ }
317+ graph = & ov_ctx -> graphs [graph_idx ];
318+
319+ memset (graph , 0 , sizeof (* graph ));
262320 CHECK_OV_STATUS (
263- ov_core_read_model (ov_ctx -> core , filename , NULL , & ov_ctx -> model ), ret );
321+ ov_core_read_model (ov_ctx -> core , filename , NULL , & graph -> model ), ret );
264322
265- ret = success ;
323+ CHECK_OV_STATUS (ov_core_compile_model (ov_ctx -> core , graph -> model , "CPU" , 0 ,
324+ & graph -> compiled_model ),
325+ ret );
326+
327+ * g = graph_idx ;
328+ ov_ctx -> n_graphs ++ ;
329+ return success ;
266330fail :
331+ free_graph (graph );
267332 return ret ;
268333}
269334
270335__attribute__((visibility ("default" ))) wasi_nn_error
271336init_execution_context (void * ctx , graph g , graph_execution_context * exec_ctx )
272337{
338+ OpenVINOContext * ov_ctx = (OpenVINOContext * )ctx ;
339+ struct OpenVINOGraph * graph ;
340+ struct OpenVINOExecutionContext * exec ;
341+ unsigned int exec_idx ;
342+ wasi_nn_error ret ;
343+
344+ if (g >= ov_ctx -> n_graphs )
345+ return runtime_error ;
346+ graph = & ov_ctx -> graphs [g ];
347+
348+ exec_idx = ov_ctx -> n_execution_contexts ;
349+ if (exec_idx >= MAX_EXECUTION_CONTEXTS )
350+ return runtime_error ;
351+ exec = & ov_ctx -> execution_contexts [exec_idx ];
352+
353+ memset (exec , 0 , sizeof (* exec ));
354+ exec -> graph = graph ;
355+
356+ CHECK_OV_STATUS (ov_compiled_model_create_infer_request (
357+ graph -> compiled_model , & exec -> infer_request ),
358+ ret );
359+
360+ * exec_ctx = exec_idx ;
361+ ov_ctx -> n_execution_contexts ++ ;
273362 return success ;
363+ fail :
364+ return ret ;
274365}
275366
276367__attribute__((visibility ("default" ))) wasi_nn_error
277368set_input (void * ctx , graph_execution_context exec_ctx , uint32_t index ,
278369 tensor * wasi_nn_tensor )
279370{
280371 OpenVINOContext * ov_ctx = (OpenVINOContext * )ctx ;
372+ struct OpenVINOExecutionContext * exec ;
281373 wasi_nn_error ret = unsupported_operation ;
282374 ov_shape_t input_shape = { 0 };
375+ ov_tensor_t * input_tensor = NULL ;
283376 int64_t * ov_dims = NULL ;
284377
378+ if (exec_ctx >= ov_ctx -> n_execution_contexts )
379+ return runtime_error ;
380+ exec = & ov_ctx -> execution_contexts [exec_ctx ];
381+
285382 /* wasi_nn_tensor -> ov_tensor */
286383 {
287384 ret = uint32_array_to_int64_array (wasi_nn_tensor -> dimensions -> size ,
@@ -306,27 +403,20 @@ set_input(void *ctx, graph_execution_context exec_ctx, uint32_t index,
306403
307404 CHECK_OV_STATUS (ov_tensor_create_from_host_ptr (input_type , input_shape ,
308405 wasi_nn_tensor -> data ,
309- & ov_ctx -> input_tensor ),
406+ & input_tensor ),
310407 ret );
311408 }
312409
313- CHECK_OV_STATUS (ov_core_compile_model (ov_ctx -> core , ov_ctx -> model , "CPU" , 0 ,
314- & ov_ctx -> compiled_model ),
315- ret );
316-
317- CHECK_OV_STATUS (ov_compiled_model_create_infer_request (
318- ov_ctx -> compiled_model , & ov_ctx -> infer_request ),
319- ret );
320-
321410 /* install ov_tensor -> infer_request */
322411 CHECK_OV_STATUS (ov_infer_request_set_input_tensor_by_index (
323- ov_ctx -> infer_request , index , ov_ctx -> input_tensor ),
412+ exec -> infer_request , index , input_tensor ),
324413 ret );
325414 ret = success ;
326-
327415fail :
328416 if (ov_dims )
329417 os_free (ov_dims );
418+ if (input_tensor )
419+ ov_tensor_free (input_tensor );
330420 ov_shape_free (& input_shape );
331421
332422 return ret ;
@@ -336,9 +426,14 @@ __attribute__((visibility("default"))) wasi_nn_error
336426compute (void * ctx , graph_execution_context exec_ctx )
337427{
338428 OpenVINOContext * ov_ctx = (OpenVINOContext * )ctx ;
429+ struct OpenVINOExecutionContext * exec ;
339430 wasi_nn_error ret = unsupported_operation ;
340431
341- CHECK_OV_STATUS (ov_infer_request_infer (ov_ctx -> infer_request ), ret );
432+ if (exec_ctx >= ov_ctx -> n_execution_contexts )
433+ return runtime_error ;
434+ exec = & ov_ctx -> execution_contexts [exec_ctx ];
435+
436+ CHECK_OV_STATUS (ov_infer_request_infer (exec -> infer_request ), ret );
342437 ret = success ;
343438fail :
344439 return ret ;
@@ -349,13 +444,18 @@ get_output(void *ctx, graph_execution_context exec_ctx, uint32_t index,
349444 tensor_data output_tensor , uint32_t * output_tensor_size )
350445{
351446 OpenVINOContext * ov_ctx = (OpenVINOContext * )ctx ;
447+ struct OpenVINOExecutionContext * exec ;
352448 wasi_nn_error ret = unsupported_operation ;
353449 ov_tensor_t * ov_tensor = NULL ;
354450 void * data = NULL ;
355451 size_t byte_size = 0 ;
356452
453+ if (exec_ctx >= ov_ctx -> n_execution_contexts )
454+ return runtime_error ;
455+ exec = & ov_ctx -> execution_contexts [exec_ctx ];
456+
357457 CHECK_OV_STATUS (ov_infer_request_get_output_tensor_by_index (
358- ov_ctx -> infer_request , index , & ov_tensor ),
458+ exec -> infer_request , index , & ov_tensor ),
359459 ret );
360460
361461 CHECK_OV_STATUS (ov_tensor_get_byte_size (ov_tensor , & byte_size ), ret );
@@ -421,27 +521,16 @@ __attribute__((visibility("default"))) wasi_nn_error
421521deinit_backend (void * ctx )
422522{
423523 OpenVINOContext * ov_ctx = (OpenVINOContext * )ctx ;
524+ unsigned int i ;
424525
425526 if (!ov_ctx )
426527 return invalid_argument ;
427528
428- if (ov_ctx -> weight_data )
429- os_free (ov_ctx -> weight_data );
430-
431- if (ov_ctx -> weights_tensor )
432- ov_tensor_free (ov_ctx -> weights_tensor );
433-
434- if (ov_ctx -> input_tensor )
435- ov_tensor_free (ov_ctx -> input_tensor );
436-
437- if (ov_ctx -> infer_request )
438- ov_infer_request_free (ov_ctx -> infer_request );
439-
440- if (ov_ctx -> compiled_model )
441- ov_compiled_model_free (ov_ctx -> compiled_model );
529+ for (i = 0 ; i < ov_ctx -> n_execution_contexts ; i ++ )
530+ free_execution_context (& ov_ctx -> execution_contexts [i ]);
442531
443- if ( ov_ctx -> model )
444- ov_model_free ( ov_ctx -> model );
532+ for ( i = 0 ; i < ov_ctx -> n_graphs ; i ++ )
533+ free_graph ( & ov_ctx -> graphs [ i ] );
445534
446535 if (ov_ctx -> core )
447536 ov_core_free (ov_ctx -> core );
0 commit comments