Skip to content

Commit 488d037

Browse files
committed
coll/basic: fix non standard ddt handling
- correctly handle ddt with a non-zero lower bound; - correctly handle ddt with size > extent. Thanks to Yuki Matsumoto for the report.
1 parent c06fb04 commit 488d037

File tree

3 files changed

+45
-39
lines changed

3 files changed

+45
-39
lines changed

ompi/mca/coll/basic/coll_basic_allgather.c

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* University of Stuttgart. All rights reserved.
1010
* Copyright (c) 2004-2005 The Regents of the University of California.
1111
* All rights reserved.
12-
* Copyright (c) 2014-2015 Research Organization for Information Science
12+
* Copyright (c) 2014-2016 Research Organization for Information Science
1313
* and Technology (RIST). All rights reserved.
1414
* $COPYRIGHT$
1515
*
@@ -48,8 +48,9 @@ mca_coll_basic_allgather_inter(const void *sbuf, int scount,
4848
mca_coll_base_module_t *module)
4949
{
5050
int rank, root = 0, size, rsize, err, i, line;
51-
char *tmpbuf = NULL, *ptmp;
52-
ptrdiff_t rlb, slb, rextent, sextent, incr;
51+
char *tmpbuf_free = NULL, *tmpbuf, *ptmp;
52+
ptrdiff_t rlb, rextent, incr;
53+
ptrdiff_t gap, span;
5354
ompi_request_t *req;
5455
ompi_request_t **reqs = NULL;
5556

@@ -75,8 +76,6 @@ mca_coll_basic_allgather_inter(const void *sbuf, int scount,
7576
/* receive a msg. from all other procs. */
7677
err = ompi_datatype_get_extent(rdtype, &rlb, &rextent);
7778
if (OMPI_SUCCESS != err) { line = __LINE__; goto exit; }
78-
err = ompi_datatype_get_extent(sdtype, &slb, &sextent);
79-
if (OMPI_SUCCESS != err) { line = __LINE__; goto exit; }
8079

8180
/* Get a requests arrays of the right size */
8281
reqs = coll_base_comm_get_reqs(module->base_data, rsize + 1);
@@ -107,8 +106,10 @@ mca_coll_basic_allgather_inter(const void *sbuf, int scount,
107106
if (OMPI_SUCCESS != err) { line = __LINE__; goto exit; }
108107

109108
/* Step 2: exchange the results between the root processes */
110-
tmpbuf = (char *) malloc(scount * size * sextent);
111-
if (NULL == tmpbuf) { line = __LINE__; err = OMPI_ERR_OUT_OF_RESOURCE; goto exit; }
109+
span = opal_datatype_span(&sdtype->super, scount * size, &gap);
110+
tmpbuf_free = (char *) malloc(span);
111+
if (NULL == tmpbuf_free) { line = __LINE__; err = OMPI_ERR_OUT_OF_RESOURCE; goto exit; }
112+
tmpbuf = tmpbuf_free - gap;
112113

113114
err = MCA_PML_CALL(isend(rbuf, rsize * rcount, rdtype, 0,
114115
MCA_COLL_BASE_TAG_ALLGATHER,
@@ -158,8 +159,8 @@ mca_coll_basic_allgather_inter(const void *sbuf, int scount,
158159
(void)line; // silence compiler warning
159160
if( NULL != reqs ) ompi_coll_base_free_reqs(reqs, rsize+1);
160161
}
161-
if (NULL != tmpbuf) {
162-
free(tmpbuf);
162+
if (NULL != tmpbuf_free) {
163+
free(tmpbuf_free);
163164
}
164165

165166
return err;

ompi/mca/coll/basic/coll_basic_reduce_scatter.c

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* Copyright (c) 2012 Oak Ridge National Labs. All rights reserved.
1515
* Copyright (c) 2013 Los Alamos National Security, LLC. All rights
1616
* reserved.
17-
* Copyright (c) 2014-2015 Research Organization for Information Science
17+
* Copyright (c) 2014-2016 Research Organization for Information Science
1818
* and Technology (RIST). All rights reserved.
1919
* $COPYRIGHT$
2020
*
@@ -367,8 +367,9 @@ mca_coll_basic_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rco
367367
{
368368
int err, i, rank, root = 0, rsize, lsize;
369369
int totalcounts;
370-
ptrdiff_t lb, extent;
370+
ptrdiff_t gap, span;
371371
char *tmpbuf = NULL, *tmpbuf2 = NULL;
372+
char *lbuf, *buf;
372373
ompi_request_t *req;
373374
int *disps = NULL;
374375

@@ -399,10 +400,7 @@ mca_coll_basic_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rco
399400
* its size is the same as the local communicator size.
400401
*/
401402
if (rank == root) {
402-
err = ompi_datatype_get_extent(dtype, &lb, &extent);
403-
if (OMPI_SUCCESS != err) {
404-
return OMPI_ERROR;
405-
}
403+
span = opal_datatype_span(&dtype->super, totalcounts, &gap);
406404

407405
/* Generate displacements for the scatterv part */
408406
disps = (int*) malloc(sizeof(int) * lsize);
@@ -414,12 +412,14 @@ mca_coll_basic_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rco
414412
disps[i + 1] = disps[i] + rcounts[i];
415413
}
416414

417-
tmpbuf = (char *) malloc(totalcounts * extent);
418-
tmpbuf2 = (char *) malloc(totalcounts * extent);
415+
tmpbuf = (char *) malloc(span);
416+
tmpbuf2 = (char *) malloc(span);
419417
if (NULL == tmpbuf || NULL == tmpbuf2) {
420418
err = OMPI_ERR_OUT_OF_RESOURCE;
421419
goto exit;
422420
}
421+
lbuf = tmpbuf - gap;
422+
buf = tmpbuf2 - gap;
423423

424424
/* Do a send-recv between the two root procs. to avoid deadlock */
425425
err = MCA_PML_CALL(isend(sbuf, totalcounts, dtype, 0,
@@ -429,7 +429,7 @@ mca_coll_basic_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rco
429429
goto exit;
430430
}
431431

432-
err = MCA_PML_CALL(recv(tmpbuf2, totalcounts, dtype, 0,
432+
err = MCA_PML_CALL(recv(lbuf, totalcounts, dtype, 0,
433433
MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm,
434434
MPI_STATUS_IGNORE));
435435
if (OMPI_SUCCESS != err) {
@@ -444,18 +444,21 @@ mca_coll_basic_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rco
444444

445445
/* Loop receiving and calling reduction function (C or Fortran)
446446
* The result of this reduction operations is then in
447-
* tmpbuf2.
447+
* lbuf.
448448
*/
449449
for (i = 1; i < rsize; i++) {
450-
err = MCA_PML_CALL(recv(tmpbuf, totalcounts, dtype, i,
450+
char *tbuf;
451+
err = MCA_PML_CALL(recv(buf, totalcounts, dtype, i,
451452
MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm,
452453
MPI_STATUS_IGNORE));
453454
if (MPI_SUCCESS != err) {
454455
goto exit;
455456
}
456457

457458
/* Perform the reduction */
458-
ompi_op_reduce(op, tmpbuf, tmpbuf2, totalcounts, dtype);
459+
ompi_op_reduce(op, lbuf, buf, totalcounts, dtype);
460+
/* swap the buffers */
461+
tbuf = lbuf; lbuf = buf; buf = tbuf;
459462
}
460463
} else {
461464
/* If not root, send data to the root. */
@@ -468,7 +471,7 @@ mca_coll_basic_reduce_scatter_inter(const void *sbuf, void *rbuf, const int *rco
468471
}
469472

470473
/* Now do a scatterv on the local communicator */
471-
err = comm->c_local_comm->c_coll.coll_scatterv(tmpbuf2, rcounts, disps, dtype,
474+
err = comm->c_local_comm->c_coll.coll_scatterv(lbuf, rcounts, disps, dtype,
472475
rbuf, rcounts[rank], dtype, 0,
473476
comm->c_local_comm,
474477
comm->c_local_comm->c_coll.coll_scatterv_module);

ompi/mca/coll/basic/coll_basic_reduce_scatter_block.c

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
* Copyright (c) 2008 Sun Microsystems, Inc. All rights reserved.
1313
* Copyright (c) 2012 Oak Ridge National Labs. All rights reserved.
1414
* Copyright (c) 2012 Sandia National Laboratories. All rights reserved.
15-
* Copyright (c) 2014-2015 Research Organization for Information Science
15+
* Copyright (c) 2014-2016 Research Organization for Information Science
1616
* and Technology (RIST). All rights reserved.
1717
* $COPYRIGHT$
1818
*
@@ -58,7 +58,7 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
5858
mca_coll_base_module_t *module)
5959
{
6060
int rank, size, count, err = OMPI_SUCCESS;
61-
ptrdiff_t extent, buf_size, gap;
61+
ptrdiff_t gap, span;
6262
char *recv_buf = NULL, *recv_buf_free = NULL;
6363

6464
/* Initialize */
@@ -72,8 +72,7 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
7272
}
7373

7474
/* get datatype information */
75-
ompi_datatype_type_extent(dtype, &extent);
76-
buf_size = opal_datatype_span(&dtype->super, count, &gap);
75+
span = opal_datatype_span(&dtype->super, count, &gap);
7776

7877
/* Handle MPI_IN_PLACE */
7978
if (MPI_IN_PLACE == sbuf) {
@@ -83,12 +82,12 @@ mca_coll_basic_reduce_scatter_block_intra(const void *sbuf, void *rbuf, int rcou
8382
if (0 == rank) {
8483
/* temporary receive buffer. See coll_basic_reduce.c for
8584
details on sizing */
86-
recv_buf_free = (char*) malloc(buf_size);
87-
recv_buf = recv_buf_free - gap;
85+
recv_buf_free = (char*) malloc(span);
8886
if (NULL == recv_buf_free) {
8987
err = OMPI_ERR_OUT_OF_RESOURCE;
9088
goto cleanup;
9189
}
90+
recv_buf = recv_buf_free - gap;
9291
}
9392

9493
/* reduction */
@@ -126,8 +125,9 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
126125
{
127126
int err, i, rank, root = 0, rsize, lsize;
128127
int totalcounts;
129-
ptrdiff_t lb, extent;
128+
ptrdiff_t gap, span;
130129
char *tmpbuf = NULL, *tmpbuf2 = NULL;
130+
char *lbuf, *buf;
131131
ompi_request_t *req;
132132

133133
rank = ompi_comm_rank(comm);
@@ -151,16 +151,15 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
151151
*
152152
*/
153153
if (rank == root) {
154-
err = ompi_datatype_get_extent(dtype, &lb, &extent);
155-
if (OMPI_SUCCESS != err) {
156-
return OMPI_ERROR;
157-
}
154+
span = opal_datatype_span(&dtype->super, totalcounts, &gap);
158155

159-
tmpbuf = (char *) malloc(totalcounts * extent);
160-
tmpbuf2 = (char *) malloc(totalcounts * extent);
156+
tmpbuf = (char *) malloc(span);
157+
tmpbuf2 = (char *) malloc(span);
161158
if (NULL == tmpbuf || NULL == tmpbuf2) {
162159
return OMPI_ERR_OUT_OF_RESOURCE;
163160
}
161+
lbuf = tmpbuf - gap;
162+
buf = tmpbuf2 - gap;
164163

165164
/* Do a send-recv between the two root procs. to avoid deadlock */
166165
err = MCA_PML_CALL(isend(sbuf, totalcounts, dtype, 0,
@@ -170,7 +169,7 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
170169
goto exit;
171170
}
172171

173-
err = MCA_PML_CALL(recv(tmpbuf2, totalcounts, dtype, 0,
172+
err = MCA_PML_CALL(recv(lbuf, totalcounts, dtype, 0,
174173
MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm,
175174
MPI_STATUS_IGNORE));
176175
if (OMPI_SUCCESS != err) {
@@ -188,15 +187,18 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
188187
* lbuf.
189188
*/
190189
for (i = 1; i < rsize; i++) {
191-
err = MCA_PML_CALL(recv(tmpbuf, totalcounts, dtype, i,
190+
char *tbuf;
191+
err = MCA_PML_CALL(recv(buf, totalcounts, dtype, i,
192192
MCA_COLL_BASE_TAG_REDUCE_SCATTER, comm,
193193
MPI_STATUS_IGNORE));
194194
if (MPI_SUCCESS != err) {
195195
goto exit;
196196
}
197197

198198
/* Perform the reduction */
199-
ompi_op_reduce(op, tmpbuf, tmpbuf2, totalcounts, dtype);
199+
ompi_op_reduce(op, lbuf, buf, totalcounts, dtype);
200+
/* swap the buffers */
201+
tbuf = lbuf; lbuf = buf; buf = tbuf;
200202
}
201203
} else {
202204
/* If not root, send data to the root. */
@@ -209,7 +211,7 @@ mca_coll_basic_reduce_scatter_block_inter(const void *sbuf, void *rbuf, int rcou
209211
}
210212

211213
/* Now do a scatter on the local communicator */
212-
err = comm->c_local_comm->c_coll.coll_scatter(tmpbuf2, rcount, dtype,
214+
err = comm->c_local_comm->c_coll.coll_scatter(lbuf, rcount, dtype,
213215
rbuf, rcount, dtype, 0,
214216
comm->c_local_comm,
215217
comm->c_local_comm->c_coll.coll_scatter_module);

0 commit comments

Comments
 (0)