View Javadoc

1   
2   package org.catacomb.numeric.math;
3   
4   
5   
/**
 * Direct solver for block-structured linear systems in which the rows of
 * each block couple only to the unknowns of that block and its immediate
 * neighbours, plus a small set of {@code ne} unknowns shared by all blocks.
 *
 * <p>Each block is stored column-major as {@code a[k][col][row]}. The blocks
 * are reduced in order: each one eliminates the columns already solved for by
 * the block above and is then diagonalised with {@link #diag}. The unknowns
 * are finally recovered by back-substitution.
 */
public abstract class DiagonalBlockMatrix {

    private static double abs(double d) {
        return Math.abs(d);
    }


    private static double max(double a, double b) {
        return (a > b ? a : b);
    }


    /** Runs {@link #dbmTest()}. */
    public static void main(String[] argv) {
        dbmTest();
    }

    /** Prints {@code s} on its own line to standard output. */
    public static void Sp(String s) {
        System.out.println(s);
    }


    /**
     * Self-check: builds a 99-block system with pseudo-random coefficients,
     * solves it, multiplies the solution back through an untouched copy of
     * the coefficients and prints the largest residual and the solve time.
     */
    public static final void dbmTest() {
        /*
          on laptop with nv = 6, nm=100, ne=1 takes about 95 ms,
          compared to 25 ms, for g77 version,  7 ms for g77 -O version.
         */

        int nv = 8;
        int nm = 99;
        int ne = 1;
        double[][][] a = new double[nm][3 * nv + ne + 1][nv + 4];
        // b keeps a pristine copy because dbmSolve overwrites a
        double[][][] b = new double[nm][3 * nv + ne + 1][nv + 4];
        int[] nrpb = new int[nm];

        // deterministic pseudo-random fill
        int ii = 1234;
        for (int i = 0; i < nm; i++) {
            for (int j = 0; j < nv + 4; j++) {
                for (int k = 0; k < 3 * nv + ne + 1; k++) {
                    ii = 7 * ii;
                    ii = ii % 14567;
                    double ddd = 0.00345 * ii;
                    a[i][k][j] = ddd;
                    b[i][k][j] = a[i][k][j];
                }
            }
            nrpb[i] = nv;
        }

        // uneven row counts; the net change is +1 == ne, so the total stays
        // nv * nm + ne as dbmSolve requires
        nrpb[0] -= 2;
        nrpb[nm - 1] += 3;

        long ttt = System.currentTimeMillis();
        // argument order must match the declaration dbmSolve(nm, nv, ne, ...)
        double[] corr = dbmSolve(nm, nv, ne, nrpb, a);
        ttt = System.currentTimeMillis() - ttt;

        double maxdev = 0.0;

        for (int i = 0; i < nm; i++) {
            // block i's first column refers to unknown nv * max(i - 1, 0)
            int k0 = i - 1;
            if (k0 < 0) {
                k0 = 0;
            }
            // first and last blocks have one neighbour, interior blocks two
            int kk = 3;
            if (i <= 0 || i == nm - 1) {
                kk = 2;
            }

            for (int j = 0; j < nrpb[i]; j++) {
                double v = 0.0;
                for (int k = 0; k < kk * nv; k++) {
                    v += b[i][k][j] * corr[nv * k0 + k];
                }
                for (int k = 0; k < ne; k++) {
                    v += b[i][kk * nv + k][j] * corr[nv * nm + k];
                }
                double dev = v - b[i][kk * nv + ne][j];
                dev = Math.abs(dev);
                if (dev > maxdev) {
                    maxdev = dev;
                }
                if (dev > 1.e-3) {
                    Sp("  " + i + " " + j + " " + v + " " +
                       b[i][kk * nv + ne][j]);
                }
            }
        }

        if (maxdev > 0.01) {
            int nl = corr.length;
            Sp("corr elts " + corr[nl - 1] + " " + corr[nl - 2]);
        }

        Sp("max deviation " + maxdev);
        Sp(" calc time " + ttt);
    }


    /**
     * Solves the block system stored in {@code a} and returns the unknowns.
     *
     * <p>Block {@code k} is stored as {@code a[k][col][row]}. Its columns are,
     * in order:
     * <ul>
     * <li>the coefficients of consecutive unknowns starting at
     * {@code nv * max(k - 1, 0)}: {@code 3 * nv} of them for interior blocks,
     * {@code 2 * nv} for the first and last blocks, and {@code nv} when there
     * is only one block;</li>
     * <li>{@code ne} columns for the shared unknowns, which are stored at the
     * end of the result;</li>
     * <li>the right-hand side.</li>
     * </ul>
     *
     * <p><b>The contents of {@code a} are destroyed</b>: column arrays are
     * shifted and some entries are set to {@code null}.
     *
     * @param nm   number of blocks (at least 1)
     * @param nv   number of unknowns per block
     * @param ne   number of shared unknowns, stored after the
     *             {@code nv * nm} block unknowns
     * @param nrpb number of rows in each block; the counts must sum to
     *             {@code nv * nm + ne}, otherwise the last block does not
     *             reduce to a single column
     * @param a    block coefficients as described above; overwritten
     * @return the {@code nv * nm + ne} unknowns
     * @throws IllegalArgumentException if {@code nm < 1}, or if the last
     *         block does not reduce to a lone right-hand-side column
     */
    public static final double[] dbmSolve(int nm, int nv, int ne,
                                          int[] nrpb, double[][][] a) {
        if (nm < 1) {
            throw new IllegalArgumentException("nm must be at least 1, got " + nm);
        }

        // nnz[k]: columns of block k left after diagonalisation
        // (coupling columns, then the ne shared columns, then the r.h.s.)
        int[] nnz = new int[nm];
        double[] corr = new double[nv * nm + ne];

        // column count of an interior block: three sets of nv unknowns,
        // ne shared unknowns and the r.h.s.
        int ncollx = 3 * nv + ne + 1;

        // nelim[k]: leading columns of block k whose unknowns are already
        // diagonalised by the rows of block k-1 (zero for block 0)
        int[] nelim = new int[nm];
        if (nm > 1) {
            nelim[1] = nrpb[0];
        }
        for (int k = 2; k < nm; k++) {
            nelim[k] = nelim[k - 1] + nrpb[k - 1] - nv;
        }

        // forward pass over blocks 0 to nm-1
        for (int k = 0; k < nm; k++) {
            // each missing neighbour removes one set of nv columns; a lone
            // block (nm == 1) is missing both
            int ncoll = ncollx;
            if (k == 0) {
                ncoll -= nv;
            }
            if (k == nm - 1) {
                ncoll -= nv;
            }
            double[][] b = a[k];
            int nel = nelim[k];

            // eliminate the first nel columns of this block using the
            // diagonalised rows of the block above
            if (nel > 0) {
                double[][] s = a[k - 1];
                // remaining unknown columns of the block above; they line up
                // with columns nel.. of this block
                int novlp = nnz[k - 1] - ne - 1;
                // the last nel rows of the block above (starting at nu) solve
                // for the first nel unknowns of this block
                int nu = nrpb[k - 1] - nel;
                int ns = nnz[k - 1];

                for (int ic = 0; ic < nel; ic++) {
                    for (int ir = 0; ir < nrpb[k]; ir++) {
                        double f = b[ic][ir];
                        for (int j = 0; j < novlp; j++) {
                            b[j + nel][ir] -= f * s[j][nu + ic];
                        }
                        // do the same to the shared-unknown columns and the r.h.s.
                        for (int j = 1; j < ne + 2; j++) {
                            b[ncoll - j][ir] -= f * s[ns - j][nu + ic];
                        }
                    }
                }
            }

            // drop the eliminated columns, then diagonalise what remains
            for (int ic = 0; ic < ncoll - nel; ic++) {
                b[ic] = b[nel + ic];
            }
            // diag prints its own failure message; its status is not checked,
            // so a singular block yields meaningless values, not an exception
            diag(nrpb[k], ncoll - nel, b, k);

            // keep only the columns to the right of the (implicit) identity
            nnz[k] = ncoll - nel - nrpb[k];
            for (int ic = 0; ic < nnz[k]; ic++) {
                b[ic] = b[ic + nrpb[k]];
            }
            for (int ic = nnz[k]; ic < ncoll; ic++) {
                b[ic] = null;
            }
        }

        // a well-formed system leaves only the r.h.s. in the last block
        if (nnz[nm - 1] != 1) {
            throw new IllegalArgumentException("solve error " + nnz[nm - 1]
                    + ": last block did not reduce to a single r.h.s. column;"
                    + " nrpb must sum to nv * nm + ne");
        }


        // back-substitution from the last block to the first;
        // l is one past the index of block k's last unknown
        int l = nv * nm + ne;
        for (int k = nm - 1; k >= 0; k--) {
            double[][] b = a[k];
            int nz = nnz[k];
            int nr = nrpb[k];
            for (int i = 1; i <= nr; i++) {
                double c = b[nz - 1][nr - i];  // r.h.s.

                // terms involving the shared unknowns
                for (int j = 1; j <= ne && j < nz; j++) {
                    c -= b[nz - j - 1][nr - i] * corr[nv * nm + ne - j];
                }
                // terms involving the already-solved unknowns after this block
                for (int j = 0; j < nz - 1 - ne; j++) {
                    c -= b[j][nr - i] * corr[l + j];
                }
                corr[l - i] = c;
            }
            l -= nrpb[k];
        }

        return corr;
    }


    /**
     * Gauss-Jordan reduction of the {@code nr} rows held in {@code s}, with
     * row pivoting on scaled magnitudes (implicit pivoting).
     *
     * <p>{@code s} is indexed {@code s[col][row]}. The last of the {@code nc}
     * columns is treated as the right-hand side; it is excluded from the row
     * scaling. On success, columns {@code nr .. nc-1} hold the reduced values
     * for the left {@code nr x nr} part taken as the identity. Columns
     * {@code 0 .. nr-1} are left partially updated and must be ignored.
     *
     * @param nr    number of rows (and of diagonal columns)
     * @param nc    number of columns in use, including the r.h.s.
     * @param s     column arrays, modified in place
     * @param block block index, used only in error messages
     * @return 0 on success; -1 if a row is entirely zero or no non-zero pivot
     *         exists (a message is printed in both cases)
     */
    public static final int diag(int nr, int nc, double[][] s, int block) {

        double[] rn = new double[nr];
        double[] pivr = new double[nc];

        // record 1 / (largest |coefficient|) of each row, excluding the r.h.s.,
        // so that pivot candidates are compared independent of row scale
        for (int ir = 0; ir < nr; ir++) {
            rn[ir] = 0.;
            for (int ic = 0; ic < nc - 1; ic++) {
                rn[ir] = max(abs(s[ic][ir]), rn[ir]);
            }
            if (rn[ir] <= 0.0) {
                Sp("row sum 0 in block " + block + "  row " + ir);
                return -1;
            }
            rn[ir] = 1. / rn[ir];
        }

        // diagonalise left hand end
        for (int ir = 0; ir < nr; ir++) {
            // choose pivot: largest scaled entry in column ir among rows ir..
            int k = ir;
            double mx = abs(s[ir][ir] * rn[ir]);
            for (int l = ir + 1; l < nr; l++) {
                double v = abs(rn[l] * s[ir][l]);
                if (v > mx) {
                    mx = v;
                    k = l;
                }
            }
            if (s[ir][k] == 0) {
                Sp("no pivot in block " + block + " for row " + ir);
                return -1;
            }

            // normalise the pivot row into pivr and move row ir into slot k
            double f = 1.0 / s[ir][k];
            for (int i = 0; i < nc; i++) {
                pivr[i] = s[i][k] * f;
                s[i][k] = s[i][ir];
            }
            // row ir's contents now live in slot k, so its scale factor moves
            // with it; slot ir takes no further part in pivot selection
            rn[k] = rn[ir];

            // eliminate elements below pivot
            for (int i = ir + 1; i < nr; i++) {
                double g = s[ir][i];
                for (int j = ir; j < nc; j++) {
                    s[j][i] -= g * pivr[j];
                }
            }

            // slot in pivot row (its diagonal entry is implicitly 1)
            for (int j = ir + 1; j < nc; j++) {
                s[j][ir] = pivr[j];
            }
        }

        // eliminate elements above diagonal; only columns from nr onward are
        // updated because the left part is treated as the identity
        for (int ir = nr - 2; ir >= 0; ir--) {
            for (int ic = ir + 1; ic < nr; ic++) {
                double g = s[ic][ir];
                for (int j = nr; j < nc; j++) {
                    s[j][ir] -= g * s[j][ic];
                }
            }
        }

        return 0;
    }
}
293 
294