@@ -378,9 +378,13 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
378378 }
379379}
380380
381- static void init_view (struct ggml_allocr * alloc , struct ggml_tensor * view ) {
381+ static void init_view (struct ggml_allocr * alloc , struct ggml_tensor * view , bool update_backend ) {
382382 assert (view -> view_src != NULL && view -> view_src -> data != NULL );
383- view -> backend = view -> view_src -> backend ;
383+
384+ if (update_backend ) {
385+ view -> backend = view -> view_src -> backend ;
386+ }
387+
384388 view -> buffer = view -> view_src -> buffer ;
385389 view -> data = (char * )view -> view_src -> data + view -> view_offs ;
386390
@@ -394,7 +398,7 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
394398 struct hash_node * ht = alloc -> hash_table ;
395399 if (node -> data == NULL ) {
396400 if (ggml_is_view (node )) {
397- init_view (alloc , node );
401+ init_view (alloc , node , true );
398402 } else {
399403 // see if we can reuse a parent's buffer (inplace)
400404 if (ggml_op_can_inplace (node -> op )) {
@@ -424,15 +428,14 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
424428 AT_PRINTF ("reusing view parent %s (%s) for %s\n" , parent -> name , view_src -> name , node -> name );
425429 node -> view_src = view_src ;
426430 view_src_hn -> n_views += 1 ;
427- init_view (alloc , node );
431+ init_view (alloc , node , false );
428432 return ;
429433 }
430- }
431- else {
434+ } else {
432435 AT_PRINTF ("reusing parent %s for %s\n" , parent -> name , node -> name );
433436 node -> view_src = parent ;
434437 p_hn -> n_views += 1 ;
435- init_view (alloc , node );
438+ init_view (alloc , node , false );
436439 return ;
437440 }
438441 }
@@ -463,7 +466,7 @@ size_t ggml_allocr_alloc_graph_n(
463466 hash_get (ht , view_src )-> n_views += 1 ;
464467 if (node -> buffer == NULL && node -> data != NULL ) {
465468 // view of a pre-allocated tensor, didn't call init_view() yet
466- init_view (alloc , node );
469+ init_view (alloc , node , true );
467470 }
468471 }
469472
@@ -474,7 +477,7 @@ size_t ggml_allocr_alloc_graph_n(
474477 }
475478 hash_get (ht , parent )-> n_children += 1 ;
476479 if (ggml_is_view (parent ) && parent -> buffer == NULL && parent -> data != NULL ) {
477- init_view (alloc , parent );
480+ init_view (alloc , parent , true );
478481 }
479482 }
480483 }
0 commit comments