summaryrefslogtreecommitdiff
blob: 3b57d57d0e46a6c2139cad99b7176c9416f7c868 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//=====================================================
// Copyright (C) 2011 Andrea Arteaga <andyspiros@gmail.com>
//=====================================================
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 2
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
//
// Type-generic helper macros. TYPENAME and TYPEPREFIX are not defined in
// this file; presumably the including translation unit defines them (a
// scalar type and its BLACS type letter) before including it — TODO confirm.

// NOTE(review): PRFX is not referenced in the visible part of this file;
// it may be used by an including file — confirm before removing.
#define PRFX d
// Two-level token pasting: CAT macro-expands its arguments (e.g. TYPEPREFIX)
// before CAT_ pastes the resulting tokens together with ##.
#define CAT_(x,y) x##y
#define CAT(x,y) CAT_(x,y)

// FUNCNAME(name) expands to <TYPEPREFIX>name_ ; e.g. FUNCNAME(gesd2d)
// becomes dgesd2d_ when TYPEPREFIX is d, matching the trailing-underscore
// naming of the other BLACS calls in this file.
#define FUNCNAME(name) CAT(CAT(TYPEPREFIX, name),_)
// Container type used for both the local and the gathered global matrix.
#define vector_t std::vector<TYPENAME>

#include <cstddef>
#include <vector>


inline void gather(
    const int& context,                 // [IN]
    vector_t& GlobalMatrixVector,       // [OUT] Only relevant for root
    const vector_t& LocalMatrixVector,  // [IN]
    int& GlobalRows,                    // [OUT]
    int& GlobalCols,                    // [OUT]
    int& BlockRows,                     // [IN (root) / OUT (other)]
    int& BlockCols,                     // [IN (root) / OUT (other)]
    int& LocalRows,                     // [IN]
    int& LocalCols,                     // [IN]
    const int& rootrow = 0,             // [IN]
    const int& rootcol = 0              // [IN]
) {
    /* Helper variables */
    int iONE = 1, iTWO = 2, imONE = -1;

    int myid, myrow, mycol, procrows, proccols, procnum;
    blacs_pinfo_(&myid, &procnum);
    blacs_gridinfo_(&context, &procrows, &proccols, &myrow, &mycol);
    bool iamroot = (myrow == rootrow && mycol == rootcol);
    TYPENAME *GlobalMatrix;
    const TYPENAME *LocalMatrix = &LocalMatrixVector[0];

    /* Broadcast matrix info */
    int binfo[2];
    if (iamroot) {
        binfo[0] = BlockRows;
        binfo[1] = BlockCols;

        igebs2d_(&context, "All", " ", &iTWO, &iONE, binfo, &iTWO);
    } else {
        igebr2d_(&context, "All", " ", &iTWO, &iONE, binfo, &iTWO,
                 &rootrow, &rootcol);
    }
    BlockRows  = binfo[0];
    BlockCols  = binfo[1];

    /* Retrieve matrix global dimensions */
    int minfo[2];
    minfo[0] = LocalRows; minfo[1] = LocalCols;
    igsum2d_(&context, "Col", " ", &iONE, &iONE, minfo, &iONE, &imONE, &imONE);
    igsum2d_(&context, "Row", " ", &iONE, &iONE, minfo+1, &iONE, &imONE, &imONE);
    GlobalRows = minfo[0]; GlobalCols = minfo[1];

    
    /* Reserve space on root */
    if (iamroot) {
        GlobalMatrixVector.resize(GlobalRows*GlobalCols);
        GlobalMatrix = &GlobalMatrixVector[0];
    }
        
    /* Gather matrix */
    int srcr = 0, srcc = 0;
    int SendRows, SendCols;
    int StartRow = 0, StartCol = 0;
    for (int r = 0; r < GlobalRows; r += BlockRows, srcr=(srcr+1)%procrows) {
        srcc = 0;
        
        // Is this the last row bloc?
        SendRows = BlockRows;
        if (GlobalRows-r < BlockRows)
            SendRows = GlobalRows-r;
        if (SendRows <= 0)
            SendRows = 0;
        
        for (int c=0; c<GlobalCols; c+=BlockCols, srcc=(srcc+1)%proccols) {
            
            // Is this the last column block?
            SendCols = BlockCols;
            if (GlobalCols-c < BlockCols)
                SendCols = GlobalCols-c;
            
            // Send data
            if (myrow == srcr && mycol == srcc) {
                FUNCNAME(gesd2d) (&context, &SendRows, &SendCols,
                         LocalMatrix+LocalRows*StartCol+StartRow,
                         &LocalRows, &rootrow, &rootcol
                );

                // Adjust the next starting column
                StartCol = (StartCol + SendCols) % LocalCols;
            }
            
            // Receive data
            if (iamroot) {
                FUNCNAME(gerv2d) (&context, &SendRows, &SendCols,
                         GlobalMatrix + GlobalRows*c + r,
                         &GlobalRows, &srcr, &srcc
                );
            }
        }
        
        // Adjust the next starting row
        if (myrow == srcr)
            StartRow = (StartRow + SendRows) % LocalRows;
        
    }
    
}