@@ -640,7 +640,9 @@ GGML_CALL static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_
 }
 
 GGML_CALL static bool ggml_backend_cpu_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_cpu(backend);
+    // HACK
+    static ggml_guid blas_guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+    return ggml_backend_is_cpu(backend) || ggml_guid_matches(backend->guid, &blas_guid);
 
     GGML_UNUSED(buft);
 }
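
Review note: the GUID literal above has to stay byte-for-byte in sync with the one the BLAS backend registers, which is presumably why it is flagged HACK. For readers unfamiliar with the helper, ggml_guid_matches (declared in ggml.h, where ggml_guid is a 16-byte array) most likely reduces to a byte-wise comparison. A minimal sketch under that assumption, not the actual implementation:

    #include <stdbool.h>
    #include <stdint.h>
    #include <string.h>

    typedef uint8_t ggml_guid[16];   // as declared in ggml.h
    typedef ggml_guid * ggml_guid_t;

    // equal iff all 16 bytes match
    bool guid_matches_sketch(ggml_guid_t a, ggml_guid_t b) {
        return memcmp(a, b, sizeof(ggml_guid)) == 0;
    }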
@@ -1097,15 +1099,16 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen
     return -1;
 }
 
-static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor) {
+static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) {
     ggml_backend_buffer_t buffer = tensor->buffer;
     if (buffer == NULL) {
         return -1;
     }
 
-    // find highest prio backend that supports the buffer type
+    // find highest prio backend that supports the buffer type and the op
     for (int i = 0; i < sched->n_backends; i++) {
-        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i])) {
+        if (ggml_backend_buft_supports_backend(buffer->buft, sched->backends[i]) &&
+            ggml_backend_supports_op(sched->backends[i], op)) {
             return i;
         }
     }
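
With the extra op parameter, a backend is only chosen for a pre-allocated tensor when it can both use the tensor's buffer type and execute the op itself; previously buffer compatibility alone was enough. ggml_backend_supports_op is expected to be a thin vtable dispatch like the other ggml_backend_* wrappers; a sketch under that assumption, not copied from the tree:

    bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
        return backend->iface.supports_op(backend, op);
    }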
@@ -1126,20 +1129,25 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS*GGML_SCHED
 #define GET_CAUSE(node) ""
 #endif
 
+//#define DEBUG_PASS1
+//#define DEBUG_PASS2
+//#define DEBUG_PASS3
+//#define DEBUG_PASS4
+
 // returns the backend that should be used for the node based on the current locations
 static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
     // TODO: use supports_op to check if the backend supports the op
 
     // assign pre-allocated nodes to their backend
-    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor);
+    int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor);
     if (cur_backend_id != -1) {
         SET_CAUSE(tensor, "1.dst");
         return cur_backend_id;
     }
 
     // view_src
     if (tensor->view_src != NULL) {
-        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src);
+        cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor->view_src, tensor);
         if (cur_backend_id != -1) {
             SET_CAUSE(tensor, "1.vsrc");
             return cur_backend_id;
@@ -1161,7 +1169,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
             continue;
         }
         if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
-            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src);
+            int src_backend_id = ggml_backend_sched_backend_from_buffer(sched, src, tensor);
             // check if a backend with higher prio wants to offload the op
             if (src_backend_id == sched->n_backends - 1) {
                 for (int b = 0; b < src_backend_id; b++) {
@@ -1223,10 +1231,30 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
     }
 }
 
-//#define DEBUG_PASS1
-//#define DEBUG_PASS2
-//#define DEBUG_PASS3
-//#define DEBUG_PASS4
+static int set_if_supports(ggml_backend_sched_t sched, struct ggml_tensor * node, int cur_backend_id, int * node_backend_id) {
+    if (ggml_backend_supports_op(sched->backends[cur_backend_id], node)) {
+        *node_backend_id = cur_backend_id;
+        SET_CAUSE(node, "2.2");
+    } else {
+        for (int b = 0; b < sched->n_backends; b++) {
+            if (b == cur_backend_id) {
+                continue;
+            }
+            if (ggml_backend_supports_op(sched->backends[b], node)) {
+                *node_backend_id = b;
+                cur_backend_id = b;
+                SET_CAUSE(node, "2.2");
+                break;
+            }
+        }
+    }
+    return cur_backend_id;
+}
+
+static bool buffer_supported(ggml_backend_sched_t sched, const struct ggml_tensor * t, int cur_backend_id) {
+    ggml_backend_buffer_t buf = t->view_src ? t->view_src->buffer : t->buffer;
+    return buf != NULL && ggml_backend_buft_supports_backend(buf->buft, sched->backends[cur_backend_id]);
+}
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
 static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
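
set_if_supports keeps the common case cheap: it tries the incoming backend first and only scans the full backend list when that one rejects the op. buffer_supported resolves views to the buffer that actually owns the data before asking whether the backend can use it. A stand-alone illustration of that view-resolution logic, using simplified stand-in types rather than the real ggml definitions:

    #include <stdbool.h>
    #include <stddef.h>

    struct buf  { int buft; };                    // stand-in for ggml_backend_buffer_t
    struct tens { struct tens * view_src; struct buf * buffer; };

    // placeholder for ggml_backend_buft_supports_backend()
    bool backend_accepts_buft(int buft, int backend_id) {
        return buft == backend_id;
    }

    bool buffer_supported_sketch(const struct tens * t, int backend_id) {
        // a view does not own data; check the buffer of its source instead
        struct buf * b = t->view_src ? t->view_src->buffer : t->buffer;
        return b != NULL && backend_accepts_buft(b->buft, backend_id);
    }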
@@ -1306,9 +1334,13 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 } else {
                     cur_backend_id = *node_backend_id;
                 }
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.2");
+            } else if (cur_backend_id != -1) {
+                // FIXME: clean this up
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
+                if (cur_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                }
             }
         }
     }
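
A compact model of what the new branch does during this expand pass: an unassigned node inherits the running backend, but if that assignment lands on the last backend (CPU, the lowest priority), the running id is reset so CPU does not keep propagating to the nodes that follow. A hypothetical trace with backends reduced to plain ints:

    #include <stdio.h>

    int main(void) {
        const int n_backends = 3;                  // 0 = gpu0, 1 = gpu1, 2 = cpu
        int assigned[5] = { 0, -1, 2, -1, -1 };    // -1 = not yet assigned
        int cur = -1;
        for (int i = 0; i < 5; i++) {
            if (assigned[i] != -1) {
                // pre-assigned: skip cpu (lowest prio backend), as in the patch
                cur = assigned[i] == n_backends - 1 ? -1 : assigned[i];
            } else if (cur != -1) {
                assigned[i] = cur;                 // stand-in for set_if_supports()
                if (cur == n_backends - 1) {
                    cur = -1;                      // new guard: stop propagating cpu
                }
            }
            printf("node %d -> backend %d\n", i, assigned[i]);
        }
        return 0;                                  // nodes 3 and 4 stay at -1
    }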
@@ -1328,9 +1360,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 } else {
                     cur_backend_id = *node_backend_id;
                 }
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.1");
+            } else if (cur_backend_id != -1) {
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
+                if (cur_backend_id == sched->n_backends - 1) {
+                    // skip cpu (lowest prio backend)
+                    cur_backend_id = -1;
+                }
             }
         }
     }
@@ -1345,9 +1380,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             int * node_backend_id = &tensor_backend_id(node);
             if (*node_backend_id != -1) {
                 cur_backend_id = *node_backend_id;
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.4");
+            } else if (cur_backend_id != -1) {
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
@@ -1362,9 +1396,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             int * node_backend_id = &tensor_backend_id(node);
             if (*node_backend_id != -1) {
                 cur_backend_id = *node_backend_id;
-            } else {
-                *node_backend_id = cur_backend_id;
-                SET_CAUSE(node, "2.3");
+            } else if (cur_backend_id != -1) {
+                cur_backend_id = set_if_supports(sched, node, cur_backend_id, node_backend_id);
             }
         }
     }
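
Note that the 2.3 and 2.4 hunks deliberately omit the cpu reset applied in 2.1 and 2.2: judging from the pass causes, these later passes expand over the remaining unassigned nodes, where falling back to the lowest-priority backend is acceptable, so a cpu assignment is allowed to keep propagating there.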
@@ -1448,10 +1481,12 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             }
             // check if the split has too many inputs
+            // FIXME: count the number of inputs instead of only checking when full
             if (split->n_inputs == GGML_SCHED_MAX_SPLIT_INPUTS) {
                 const size_t id = hash_id(src);
                 int src_backend_id = sched->tensor_backend_id[id];
-                if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL) {
+                bool supported = buffer_supported(sched, src, cur_backend_id);
+                if (src_backend_id != cur_backend_id && sched->tensor_copies[hash_id(src)][cur_backend_id][0] == NULL && !supported) {
                     //printf("starting new split because of too many inputs: node %s, input %s\n", node->name, src->name);
                     need_new_split = true;
                     break;
@@ -1511,7 +1546,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
                 }
             }
 
-            if (src_backend_id != node_backend_id) {
+            bool supported = buffer_supported(sched, src, cur_backend_id);
+            if (src_backend_id != cur_backend_id && !supported) {
                 // create a copy of the input in the split's backend
                 const size_t id = hash_id(src);
                 if (sched->tensor_copies[id][cur_backend_id][0] == NULL) {
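
Both of the last two hunks apply the same rule: an input only needs a copy into the split's backend when the backend ids differ and the split's backend cannot read the source buffer where it already is. The comparison also changed from node_backend_id to cur_backend_id, i.e. the split's backend that the copy would be created for. Condensed into a predicate with a hypothetical name:

    #include <stdbool.h>

    // before: src_backend_id != node_backend_id  (ids differ => always copy)
    // after:  a buffer the split's backend can use directly avoids the copy
    bool needs_input_copy(int src_backend_id, int cur_backend_id, bool buffer_usable) {
        return src_backend_id != cur_backend_id && !buffer_usable;
    }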