| 1 | #include "mpi.h"
|
|---|
| 2 | #include "stdlib.h"
|
|---|
| 3 | #include "string.h"
|
|---|
| 4 | #include "assert.h"
|
|---|
| 5 | #include "reduceScatter_optimized.h"
|
|---|
| 6 |
|
|---|
| 7 | $input int NP; // nprocs
|
|---|
| 8 | $input int N; // size of input data per proc
|
|---|
| 9 | $input int NB; // upper bound of the size of input data per proc
|
|---|
| 10 | $assume(0 < N && N < NB);
|
|---|
| 11 | $input double DATA[N * NP]; // arbitrary input data, N per proc
|
|---|
| 12 | $input int VEC[NP]; // arbitrary recv counts
|
|---|
| 13 | $assume($forall (int i : 0 .. NP-1) VEC[i] > 0);
|
|---|
| 14 |
|
|---|
/* Reduction operator under test: MPI_MAX when the harness is built with
 * _CHECK_MAX defined, MPI_SUM otherwise. The assertions in main() are
 * selected by the same macro.
 * NOTE(review): identifiers beginning with an underscore followed by an
 * uppercase letter are reserved for the implementation; the name is kept
 * so existing -D_CHECK_MAX build flags keep working. */
#if defined(_CHECK_MAX)
#  define MPIOP MPI_MAX
#else
#  define MPIOP MPI_SUM
#endif
|
|---|
| 20 |
|
|---|
| 21 | int main() {
|
|---|
| 22 | double * sendbuf, * recvbuf;
|
|---|
| 23 | int rank, size;
|
|---|
| 24 |
|
|---|
| 25 | MPI_Init(NULL, NULL);
|
|---|
| 26 | MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
|---|
| 27 | MPI_Comm_size(MPI_COMM_WORLD, &size);
|
|---|
| 28 |
|
|---|
| 29 | // NP is nprocs:
|
|---|
| 30 | $assume(NP == size);
|
|---|
| 31 |
|
|---|
| 32 | int recvcnts[size];
|
|---|
| 33 | int total_counts = 0;
|
|---|
| 34 | int my_oft = 0;
|
|---|
| 35 |
|
|---|
| 36 | // initializes recv counts
|
|---|
| 37 | // computes the total count for completing the assumption over the recv counts
|
|---|
| 38 | // computes my offset
|
|---|
| 39 | for (int i = 0; i < size; i++) {
|
|---|
| 40 | recvcnts[i] = VEC[i];
|
|---|
| 41 | total_counts += VEC[i];
|
|---|
| 42 | if (i < rank) my_oft += recvcnts[i];
|
|---|
| 43 | }
|
|---|
| 44 | // assumption over the recv count:
|
|---|
| 45 | $assume(total_counts == N);
|
|---|
| 46 |
|
|---|
| 47 | sendbuf = (double*)malloc(sizeof(double) * N);
|
|---|
| 48 | recvbuf = (double*)malloc(sizeof(double) * recvcnts[rank]);
|
|---|
| 49 |
|
|---|
| 50 | memcpy(sendbuf, DATA + rank * N, sizeof(double) * N);
|
|---|
| 51 | reduce_scatter_double(sendbuf, recvbuf, recvcnts, MPIOP, MPI_COMM_WORLD);
|
|---|
| 52 |
|
|---|
| 53 | // assertions:
|
|---|
| 54 | #ifdef _CHECK_MAX
|
|---|
| 55 |
|
|---|
| 56 | for (int i = 0; i < recvcnts[rank]; i++)
|
|---|
| 57 | assert(recvbuf[i] >= sendbuf[my_oft + i]);
|
|---|
| 58 |
|
|---|
| 59 | #else
|
|---|
| 60 |
|
|---|
| 61 | for (int i = 0; i < recvcnts[rank]; i++) {
|
|---|
| 62 | double expect = 0;
|
|---|
| 63 |
|
|---|
| 64 | for (int j = 0; j < size; j++)
|
|---|
| 65 | expect += DATA[j * N + my_oft + i];
|
|---|
| 66 | assert(recvbuf[i] == expect);
|
|---|
| 67 | }
|
|---|
| 68 | #endif
|
|---|
| 69 |
|
|---|
| 70 | free(sendbuf);
|
|---|
| 71 | free(recvbuf);
|
|---|
| 72 | MPI_Finalize();
|
|---|
| 73 | }
|
|---|