12
12
* Copyright (c) 2008 Sun Microsystems, Inc. All rights reserved.
13
13
* Copyright (c) 2012 Oak Ridge National Labs. All rights reserved.
14
14
* Copyright (c) 2012 Sandia National Laboratories. All rights reserved.
15
- * Copyright (c) 2014-2015 Research Organization for Information Science
15
+ * Copyright (c) 2014-2016 Research Organization for Information Science
16
16
* and Technology (RIST). All rights reserved.
17
17
* $COPYRIGHT$
18
18
*
@@ -58,7 +58,7 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
58
58
mca_coll_base_module_t * module )
59
59
{
60
60
int rank , size , count , err = OMPI_SUCCESS ;
61
- ptrdiff_t extent , buf_size , gap ;
61
+ ptrdiff_t gap , span ;
62
62
char * recv_buf = NULL , * recv_buf_free = NULL ;
63
63
64
64
/* Initialize */
@@ -72,8 +72,7 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
72
72
}
73
73
74
74
/* get datatype information */
75
- ompi_datatype_type_extent (dtype , & extent );
76
- buf_size = opal_datatype_span (& dtype -> super , count , & gap );
75
+ span = opal_datatype_span (& dtype -> super , count , & gap );
77
76
78
77
/* Handle MPI_IN_PLACE */
79
78
if (MPI_IN_PLACE == sbuf ) {
@@ -83,12 +82,12 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
83
82
if (0 == rank ) {
84
83
/* temporary receive buffer. See coll_basic_reduce.c for
85
84
details on sizing */
86
- recv_buf_free = (char * ) malloc (buf_size );
87
- recv_buf = recv_buf_free - gap ;
85
+ recv_buf_free = (char * ) malloc (span );
88
86
if (NULL == recv_buf_free ) {
89
87
err = OMPI_ERR_OUT_OF_RESOURCE ;
90
88
goto cleanup ;
91
89
}
90
+ recv_buf = recv_buf_free - gap ;
92
91
}
93
92
94
93
/* reduction */
@@ -126,8 +125,9 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
126
125
{
127
126
int err , i , rank , root = 0 , rsize , lsize ;
128
127
int totalcounts ;
129
- ptrdiff_t lb , extent ;
128
+ ptrdiff_t gap , span ;
130
129
char * tmpbuf = NULL , * tmpbuf2 = NULL ;
130
+ char * lbuf , * buf ;
131
131
ompi_request_t * req ;
132
132
133
133
rank = ompi_comm_rank (comm );
@@ -151,16 +151,15 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
151
151
*
152
152
*/
153
153
if (rank == root ) {
154
- err = ompi_datatype_get_extent (dtype , & lb , & extent );
155
- if (OMPI_SUCCESS != err ) {
156
- return OMPI_ERROR ;
157
- }
154
+ span = opal_datatype_span (& dtype -> super , totalcounts , & gap );
158
155
159
- tmpbuf = (char * ) malloc (totalcounts * extent );
160
- tmpbuf2 = (char * ) malloc (totalcounts * extent );
156
+ tmpbuf = (char * ) malloc (span );
157
+ tmpbuf2 = (char * ) malloc (span );
161
158
if (NULL == tmpbuf || NULL == tmpbuf2 ) {
162
159
return OMPI_ERR_OUT_OF_RESOURCE ;
163
160
}
161
+ lbuf = tmpbuf - gap ;
162
+ buf = tmpbuf2 - gap ;
164
163
165
164
/* Do a send-recv between the two root procs. to avoid deadlock */
166
165
err = MCA_PML_CALL (isend (sbuf , totalcounts , dtype , 0 ,
@@ -170,7 +169,7 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
170
169
goto exit ;
171
170
}
172
171
173
- err = MCA_PML_CALL (recv (tmpbuf2 , totalcounts , dtype , 0 ,
172
+ err = MCA_PML_CALL (recv (lbuf , totalcounts , dtype , 0 ,
174
173
MCA_COLL_BASE_TAG_REDUCE_SCATTER , comm ,
175
174
MPI_STATUS_IGNORE ));
176
175
if (OMPI_SUCCESS != err ) {
@@ -188,15 +187,18 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
188
187
* tmpbuf2.
189
188
*/
190
189
for (i = 1 ; i < rsize ; i ++ ) {
191
- err = MCA_PML_CALL (recv (tmpbuf , totalcounts , dtype , i ,
190
+ char * tbuf ;
191
+ err = MCA_PML_CALL (recv (buf , totalcounts , dtype , i ,
192
192
MCA_COLL_BASE_TAG_REDUCE_SCATTER , comm ,
193
193
MPI_STATUS_IGNORE ));
194
194
if (MPI_SUCCESS != err ) {
195
195
goto exit ;
196
196
}
197
197
198
198
/* Perform the reduction */
199
- ompi_op_reduce (op , tmpbuf , tmpbuf2 , totalcounts , dtype );
199
+ ompi_op_reduce (op , lbuf , buf , totalcounts , dtype );
200
+ /* swap the buffers */
201
+ tbuf = lbuf ; lbuf = buf ; buf = tbuf ;
200
202
}
201
203
} else {
202
204
/* If not root, send data to the root. */
@@ -209,7 +211,7 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
209
211
}
210
212
211
213
/* Now do a scatterv on the local communicator */
212
- err = comm -> c_local_comm -> c_coll .coll_scatter (tmpbuf2 , rcount , dtype ,
214
+ err = comm -> c_local_comm -> c_coll .coll_scatter (lbuf , rcount , dtype ,
213
215
rbuf , rcount , dtype , 0 ,
214
216
comm -> c_local_comm ,
215
217
comm -> c_local_comm -> c_coll .coll_scatter_module );
0 commit comments