66 *
77 * Copyright (c) 2020 Cisco Systems, Inc. All rights reserved.
88 * Copyright (c) 2022 IBM Corporation. All rights reserved
9+ * Copyright (c) 2023 Computer Architecture and VLSI Systems (CARV)
10+ * Laboratory, ICS Forth. All rights reserved.
911 * $COPYRIGHT$
1012 *
1113 * Additional copyrights may follow
2224
2325#include "coll_han.h"
2426#include "ompi/mca/coll/base/coll_base_functions.h"
27+ #include "ompi/mca/coll/base/coll_base_util.h"
2528#include "ompi/mca/coll/base/coll_tags.h"
2629#include "ompi/mca/pml/pml.h"
2730#include "coll_han_trigger.h"
@@ -43,6 +46,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
4346 struct ompi_op_t * op ,
4447 int root_up_rank ,
4548 int root_low_rank ,
49+ int root_reduce_low_rank ,
4650 struct ompi_communicator_t * up_comm ,
4751 struct ompi_communicator_t * low_comm ,
4852 int num_segments ,
@@ -59,6 +63,7 @@ mca_coll_han_set_allreduce_args(mca_coll_han_allreduce_args_t * args,
5963 args -> op = op ;
6064 args -> root_up_rank = root_up_rank ;
6165 args -> root_low_rank = root_low_rank ;
66+ args -> root_reduce_low_rank = root_reduce_low_rank ;
6267 args -> up_comm = up_comm ;
6368 args -> low_comm = low_comm ;
6469 args -> num_segments = num_segments ;
@@ -139,15 +144,26 @@ mca_coll_han_allreduce_intra(const void *sbuf,
139144 int low_rank = ompi_comm_rank (low_comm );
140145 int root_up_rank = 0 ;
141146 int root_low_rank = 0 ;
147+ int root_reduce_low_rank = 0 ;
148+
149+ mca_coll_base_avail_coll_t * low_1st_module = (mca_coll_base_avail_coll_t * )
150+ opal_list_get_last (low_comm -> c_coll -> module_list );
151+
152+ // Invoke XHC's "special" Reduce
153+ if (0 == strcmp (low_1st_module -> ac_component_name , "xhc" )
154+ && low_comm -> c_coll -> coll_reduce_module == low_1st_module -> ac_module ) {
155+ root_reduce_low_rank = -1 ;
156+ }
157+
142158 /* Create t0 task for the first segment */
143159 mca_coll_task_t * t0 = OBJ_NEW (mca_coll_task_t );
144160 /* Set up t0 task arguments */
145161 int * completed = (int * ) malloc (sizeof (int ));
146162 completed [0 ] = 0 ;
147163 mca_coll_han_allreduce_args_t * t = malloc (sizeof (mca_coll_han_allreduce_args_t ));
148164 mca_coll_han_set_allreduce_args (t , t0 , (char * ) sbuf , (char * ) rbuf , seg_count , dtype , op ,
149- root_up_rank , root_low_rank , up_comm , low_comm , num_segments , 0 ,
150- w_rank , count - (num_segments - 1 ) * seg_count ,
165+ root_up_rank , root_low_rank , root_reduce_low_rank , up_comm ,
166+ low_comm , num_segments , 0 , w_rank , count - (num_segments - 1 ) * seg_count ,
151167 low_rank != root_low_rank , NULL , completed );
152168 /* Init t0 task */
153169 init_task (t0 , mca_coll_han_allreduce_t0_task , (void * ) (t ));
@@ -215,18 +231,18 @@ int mca_coll_han_allreduce_t0_task(void *task_args)
215231 if (MPI_IN_PLACE == t -> sbuf ) {
216232 if (!t -> noop ) {
217233 t -> low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE , (char * ) t -> rbuf , t -> seg_count , t -> dtype ,
218- t -> op , t -> root_low_rank , t -> low_comm ,
234+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
219235 t -> low_comm -> c_coll -> coll_reduce_module );
220236 }
221237 else {
222238 t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> rbuf , NULL , t -> seg_count , t -> dtype ,
223- t -> op , t -> root_low_rank , t -> low_comm ,
239+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
224240 t -> low_comm -> c_coll -> coll_reduce_module );
225241 }
226242 }
227243 else {
228244 t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf , (char * ) t -> rbuf , t -> seg_count , t -> dtype ,
229- t -> op , t -> root_low_rank , t -> low_comm ,
245+ t -> op , t -> root_reduce_low_rank , t -> low_comm ,
230246 t -> low_comm -> c_coll -> coll_reduce_module );
231247 }
232248 return OMPI_SUCCESS ;
@@ -264,7 +280,7 @@ int mca_coll_han_allreduce_t1_task(void *task_args)
264280 }
265281 t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + extent * t -> seg_count ,
266282 (char * ) t -> rbuf + extent * t -> seg_count , tmp_count ,
267- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
283+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
268284 t -> low_comm -> c_coll -> coll_reduce_module );
269285
270286 }
@@ -323,7 +339,7 @@ int mca_coll_han_allreduce_t2_task(void *task_args)
323339 }
324340 t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 2 * extent * t -> seg_count ,
325341 (char * ) t -> rbuf + 2 * extent * t -> seg_count , tmp_count ,
326- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
342+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
327343 t -> low_comm -> c_coll -> coll_reduce_module );
328344 }
329345 if (!t -> noop && req_count > 0 ) {
@@ -387,7 +403,7 @@ int mca_coll_han_allreduce_t3_task(void *task_args)
387403 }
388404 t -> low_comm -> c_coll -> coll_reduce ((char * ) t -> sbuf + 3 * extent * t -> seg_count ,
389405 (char * ) t -> rbuf + 3 * extent * t -> seg_count , tmp_count ,
390- t -> dtype , t -> op , t -> root_low_rank , t -> low_comm ,
406+ t -> dtype , t -> op , t -> root_reduce_low_rank , t -> low_comm ,
391407 t -> low_comm -> c_coll -> coll_reduce_module );
392408 }
393409 /* lb of cur_seg */
@@ -421,6 +437,7 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
421437 ompi_communicator_t * low_comm ;
422438 ompi_communicator_t * up_comm ;
423439 int root_low_rank = 0 ;
440+ int root_reduce_low_rank = 0 ;
424441 int low_rank ;
425442 int ret ;
426443 mca_coll_han_module_t * han_module = (mca_coll_han_module_t * )module ;
@@ -452,22 +469,31 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
452469 up_comm = han_module -> sub_comm [INTER_NODE ];
453470 low_rank = ompi_comm_rank (low_comm );
454471
472+ mca_coll_base_avail_coll_t * low_1st_module = (mca_coll_base_avail_coll_t * )
473+ opal_list_get_last (low_comm -> c_coll -> module_list );
474+
475+ // Invoke XHC's "special" Reduce
476+ if (0 == strcmp (low_1st_module -> ac_component_name , "xhc" )
477+ && low_comm -> c_coll -> coll_reduce_module == low_1st_module -> ac_module ) {
478+ root_reduce_low_rank = -1 ;
479+ }
480+
455481 /* Low_comm reduce */
456482 if (MPI_IN_PLACE == sbuf ) {
457483 if (low_rank == root_low_rank ) {
458484 ret = low_comm -> c_coll -> coll_reduce (MPI_IN_PLACE , (char * )rbuf ,
459- count , dtype , op , root_low_rank ,
485+ count , dtype , op , root_reduce_low_rank ,
460486 low_comm , low_comm -> c_coll -> coll_reduce_module );
461487 }
462488 else {
463489 ret = low_comm -> c_coll -> coll_reduce ((char * )rbuf , NULL ,
464- count , dtype , op , root_low_rank ,
490+ count , dtype , op , root_reduce_low_rank ,
465491 low_comm , low_comm -> c_coll -> coll_reduce_module );
466492 }
467493 }
468494 else {
469495 ret = low_comm -> c_coll -> coll_reduce ((char * )sbuf , (char * )rbuf ,
470- count , dtype , op , root_low_rank ,
496+ count , dtype , op , root_reduce_low_rank ,
471497 low_comm , low_comm -> c_coll -> coll_reduce_module );
472498 }
473499 if (OPAL_UNLIKELY (OMPI_SUCCESS != ret )) {
0 commit comments