| 1 | #include <string.h>
|
|---|
| 2 | #include "allreduce.h"
|
|---|
| 3 |
|
|---|
| 4 | GS_DEFINE_DOM_SIZES()
|
|---|
| 5 |
|
|---|
| 6 | #define DEFINE_PROCS(T) \
|
|---|
| 7 | GS_FOR_EACH_OP(T,DEFINE_GATHER)
|
|---|
| 8 |
|
|---|
| 9 | GS_FOR_EACH_DOMAIN(DEFINE_PROCS)
|
|---|
| 10 |
|
|---|
| 11 | #undef DEFINE_PROCS
|
|---|
| 12 |
|
|---|
| 13 | void gs_gather_array(void *out, const void *in, uint n, gs_dom dom, gs_op op)
|
|---|
| 14 | {
|
|---|
| 15 | #define WITH_OP(T,OP) gather_array_##T##_##OP(out,in,n)
|
|---|
| 16 | #define WITH_DOMAIN(T) SWITCH_OP(T,op)
|
|---|
| 17 | SWITCH_DOMAIN(dom);
|
|---|
| 18 | #undef WITH_DOMAIN
|
|---|
| 19 | #undef WITH_OP
|
|---|
| 20 | }
|
|---|
| 21 |
|
|---|
| 22 | static void comm_send(const struct comm *c, void *p, size_t n,
|
|---|
| 23 | uint dst, int tag)
|
|---|
| 24 | {
|
|---|
| 25 | MPI_Send(p,n,MPI_UNSIGNED_CHAR,dst,tag,c->c);
|
|---|
| 26 | }
|
|---|
| 27 |
|
|---|
| 28 | static void comm_recv(const struct comm *c, void *p, size_t n,
|
|---|
| 29 | uint src, int tag)
|
|---|
| 30 | {
|
|---|
| 31 | MPI_Recv(p,n,MPI_UNSIGNED_CHAR,src,tag,c->c,MPI_STATUS_IGNORE);
|
|---|
| 32 | }
|
|---|
| 33 |
|
|---|
| 34 | static void allreduce_imp(const struct comm *com, gs_dom dom, gs_op op,
|
|---|
| 35 | void *v, uint vn, void *buf)
|
|---|
| 36 | {
|
|---|
| 37 | size_t total_size = vn*gs_dom_size[dom];
|
|---|
| 38 | const uint id=com->id, np=com->np;
|
|---|
| 39 | uint n = np, c=1, odd=0, base=0;
|
|---|
| 40 | while(n>1) {
|
|---|
| 41 | odd=(odd<<1)|(n&1);
|
|---|
| 42 | c<<=1, n>>=1;
|
|---|
| 43 | if(id>=base+n) c|=1, base+=n, n+=(odd&1);
|
|---|
| 44 | }
|
|---|
| 45 | while(n<np) {
|
|---|
| 46 | if(c&1) n-=(odd&1), base-=n;
|
|---|
| 47 | c>>=1, n<<=1, n+=(odd&1);
|
|---|
| 48 | odd>>=1;
|
|---|
| 49 | if(base==id) {
|
|---|
| 50 | comm_recv(com, buf,total_size, id+n/2,id+n/2);
|
|---|
| 51 | gs_gather_array(v,buf,vn, dom,op);
|
|---|
| 52 | } else {
|
|---|
| 53 | comm_send(com, v,total_size, base,id);
|
|---|
| 54 | break;
|
|---|
| 55 | }
|
|---|
| 56 | }
|
|---|
| 57 | while(n>1) {
|
|---|
| 58 | if(base==id)
|
|---|
| 59 | comm_send(com, v,total_size, id+n/2,id);
|
|---|
| 60 | else
|
|---|
| 61 | comm_recv(com, v,total_size, base,base);
|
|---|
| 62 | odd=(odd<<1)|(n&1);
|
|---|
| 63 | c<<=1, n>>=1;
|
|---|
| 64 | if(id>=base+n) c|=1, base+=n, n+=(odd&1);
|
|---|
| 65 | }
|
|---|
| 66 | }
|
|---|
| 67 |
|
|---|
| 68 | void comm_allreduce(const struct comm *com, gs_dom dom, gs_op op,
|
|---|
| 69 | void *v, uint vn, void *buf)
|
|---|
| 70 | {
|
|---|
| 71 | if(vn==0) return;
|
|---|
| 72 | MPI_Datatype mpitype;
|
|---|
| 73 | MPI_Op mpiop;
|
|---|
| 74 | #define DOMAIN_SWITCH() do { \
|
|---|
| 75 | switch(dom) { case gs_double: mpitype=MPI_DOUBLE; break; \
|
|---|
| 76 | case gs_float: mpitype=MPI_FLOAT; break; \
|
|---|
| 77 | case gs_int: mpitype=MPI_INT; break; \
|
|---|
| 78 | case gs_long: mpitype=MPI_LONG; break; \
|
|---|
| 79 | default: goto comm_allreduce_byhand; \
|
|---|
| 80 | } \
|
|---|
| 81 | } while(0)
|
|---|
| 82 |
|
|---|
| 83 | DOMAIN_SWITCH();
|
|---|
| 84 | #undef DOMAIN_SWITCH
|
|---|
| 85 | switch(op) { case gs_add: mpiop=MPI_SUM; break;
|
|---|
| 86 | case gs_mul: mpiop=MPI_PROD; break;
|
|---|
| 87 | case gs_min: mpiop=MPI_MIN; break;
|
|---|
| 88 | case gs_max: mpiop=MPI_MAX; break;
|
|---|
| 89 | default: goto comm_allreduce_byhand;
|
|---|
| 90 | }
|
|---|
| 91 | MPI_Allreduce(v,buf,vn,mpitype,mpiop,com->c);
|
|---|
| 92 | memcpy(v,buf,vn*gs_dom_size[dom]);
|
|---|
| 93 | return;
|
|---|
| 94 |
|
|---|
| 95 | comm_allreduce_byhand:
|
|---|
| 96 | allreduce_imp(com,dom,op, v,vn, buf);
|
|---|
| 97 | }
|
|---|