MPI 加农算法


MPI 加农算法计算矩阵相乘 为了实现方便 只实现了矩阵自乘 而且 实际上各线程都可读取整个矩阵(略去输入和分配中的通信过程)

在各线程相互交换数据的过程中 一定要注意避免死锁

  1. #include "mpi.h"   
  2. #include <stdio.h>   
  3. #include <stdlib.h>   
  4. #include <math.h>   
  5.   
  6. #define TAG 0   
  7.   
  8. static inline void rtoij(int r,int w,int *i,int *j) {  
  9.     //rank to i,j   
  10.     *i=r/w;  
  11.     *j=r-(*i)*w;  
  12. }  
  13. static inline void ijtor(int i,int j,int *r,int w) {  
  14.     //i,j to rank   
  15.     *r=i*w+j;  
  16. }  
  17. static inline int aj(int i,int j,int w) {  
  18.     return (j+i)%w;  
  19. }  
  20. static inline int bi(int i,int j,int w) {  
  21.     //initial b distribute   
  22.     return aj(i,j,w);  
  23. }  
  24. int main(int argc,char *argv[]) {  
  25.     float m[][2]={{1,2},  
  26.              {3,4}};//the matrix    
  27.     //multiply by itself   
  28.     int self,size;  
  29.     MPI_Init(&argc,&argv);  
  30.     MPI_Comm_rank(MPI_COMM_WORLD,&self);  
  31.     MPI_Comm_size(MPI_COMM_WORLD,&size);  
  32.     MPI_Request r;  
  33.     MPI_Status s;  
  34.     int w=sqrt(size);  
  35.     int i,j,an,bn,ap,bp;  
  36.     rtoij(self,w,&i,&j);  
  37.     ijtor(i,(j+w-1)%w,&an,w);//next a process   
  38.     ijtor((i+w-1)%w,j,&bn,w);//next b process   
  39.     ijtor(i,(j+1)%w,&ap,w);//previous a process   
  40.     ijtor((i+1)%w,j,&bp,w);  
  41.   
  42.     float res,a,b,tmp;  
  43.     //initialize data distribution   
  44.     a=m[i][aj(i,j,w)];  
  45.     b=m[bi(i,j,w)][j];  
  46.     res=a*b;  
  47.   
  48.     for(int i=0;i<w-1;++i) {  
  49.         MPI_Issend(&a,1,MPI_FLOAT,an,TAG,MPI_COMM_WORLD,&r);//avoid dead lock   
  50.         MPI_Recv(&tmp,1,MPI_FLOAT,ap,TAG,MPI_COMM_WORLD,&s);  
  51.         a=tmp;  
  52.         MPI_Wait(&r,&s);  
  53.   
  54.         MPI_Issend(&b,1,MPI_FLOAT,bn,TAG,MPI_COMM_WORLD,&r);//avoid dead lock   
  55.         MPI_Recv(&tmp,1,MPI_FLOAT,bp,TAG,MPI_COMM_WORLD,&s);  
  56.         b=tmp;  
  57.         MPI_Wait(&r,&s);  
  58.   
  59.         res+=a*b;  
  60.     }  
  61.     if(0==self) {  
  62.         printf("%f",res);  
  63.         for(int i=1;i<size;++i) {  
  64.             MPI_Recv(&res,1,MPI_FLOAT,i,TAG,MPI_COMM_WORLD,&s);  
  65.             printf(" %f",res);  
  66.         }  
  67.         printf("\n");  
  68.     } else {  
  69.         MPI_Ssend(&res,1,MPI_FLOAT,0,TAG,MPI_COMM_WORLD);  
  70.     }  
  71.     MPI_Finalize();  
  72.     return 0;  
  73. }  

相关内容