OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / gonum / dlarfb.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package gonum
6
7 import (
8         "gonum.org/v1/gonum/blas"
9         "gonum.org/v1/gonum/blas/blas64"
10         "gonum.org/v1/gonum/lapack"
11 )
12
13 // Dlarfb applies a block reflector to a matrix.
14 //
15 // In the call to Dlarfb, the mxn c is multiplied by the implicitly defined matrix h as follows:
16 //  c = h * c if side == Left and trans == NoTrans
17 //  c = c * h if side == Right and trans == NoTrans
18 //  c = h^T * c if side == Left and trans == Trans
19 //  c = c * h^T if side == Right and trans == Trans
20 // h is a product of elementary reflectors. direct sets the direction of multiplication
21 //  h = h_1 * h_2 * ... * h_k if direct == Forward
22 //  h = h_k * h_k-1 * ... * h_1 if direct == Backward
23 // The combination of direct and store defines the orientation of the elementary
24 // reflectors. In all cases the ones on the diagonal are implicitly represented.
25 //
26 // If direct == lapack.Forward and store == lapack.ColumnWise
27 //  V = [ 1        ]
28 //      [v1   1    ]
29 //      [v1  v2   1]
30 //      [v1  v2  v3]
31 //      [v1  v2  v3]
32 // If direct == lapack.Forward and store == lapack.RowWise
33 //  V = [ 1  v1  v1  v1  v1]
34 //      [     1  v2  v2  v2]
35 //      [         1  v3  v3]
36 // If direct == lapack.Backward and store == lapack.ColumnWise
37 //  V = [v1  v2  v3]
38 //      [v1  v2  v3]
39 //      [ 1  v2  v3]
40 //      [     1  v3]
41 //      [         1]
42 // If direct == lapack.Backward and store == lapack.RowWise
43 //  V = [v1  v1   1        ]
44 //      [v2  v2  v2   1    ]
45 //      [v3  v3  v3  v3   1]
46 // An elementary reflector can be explicitly constructed by extracting the
47 // corresponding elements of v, placing a 1 where the diagonal would be, and
48 // placing zeros in the remaining elements.
49 //
50 // t is a k×k matrix containing the block reflector, and this function will panic
51 // if t is not of sufficient size. See Dlarft for more information.
52 //
53 // work is a temporary storage matrix with stride ldwork.
54 // work must be of size at least n×k side == Left and m×k if side == Right, and
55 // this function will panic if this size is not met.
56 //
57 // Dlarfb is an internal routine. It is exported for testing purposes.
58 func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
59         if side != blas.Left && side != blas.Right {
60                 panic(badSide)
61         }
62         if trans != blas.Trans && trans != blas.NoTrans {
63                 panic(badTrans)
64         }
65         if direct != lapack.Forward && direct != lapack.Backward {
66                 panic(badDirect)
67         }
68         if store != lapack.ColumnWise && store != lapack.RowWise {
69                 panic(badStore)
70         }
71         checkMatrix(m, n, c, ldc)
72         if k < 0 {
73                 panic(kLT0)
74         }
75         checkMatrix(k, k, t, ldt)
76         nv := m
77         nw := n
78         if side == blas.Right {
79                 nv = n
80                 nw = m
81         }
82         if store == lapack.ColumnWise {
83                 checkMatrix(nv, k, v, ldv)
84         } else {
85                 checkMatrix(k, nv, v, ldv)
86         }
87         checkMatrix(nw, k, work, ldwork)
88
89         if m == 0 || n == 0 {
90                 return
91         }
92
93         bi := blas64.Implementation()
94
95         transt := blas.Trans
96         if trans == blas.Trans {
97                 transt = blas.NoTrans
98         }
99         // TODO(btracey): This follows the original Lapack code where the
100         // elements are copied into the columns of the working array. The
101         // loops should go in the other direction so the data is written
102         // into the rows of work so the copy is not strided. A bigger change
103         // would be to replace work with work^T, but benchmarks would be
104         // needed to see if the change is merited.
105         if store == lapack.ColumnWise {
106                 if direct == lapack.Forward {
107                         // V1 is the first k rows of C. V2 is the remaining rows.
108                         if side == blas.Left {
109                                 // W = C^T V = C1^T V1 + C2^T V2 (stored in work).
110
111                                 // W = C1.
112                                 for j := 0; j < k; j++ {
113                                         bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
114                                 }
115                                 // W = W * V1.
116                                 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
117                                         n, k, 1,
118                                         v, ldv,
119                                         work, ldwork)
120                                 if m > k {
121                                         // W = W + C2^T V2.
122                                         bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
123                                                 1, c[k*ldc:], ldc, v[k*ldv:], ldv,
124                                                 1, work, ldwork)
125                                 }
126                                 // W = W * T^T or W * T.
127                                 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
128                                         1, t, ldt,
129                                         work, ldwork)
130                                 // C -= V * W^T.
131                                 if m > k {
132                                         // C2 -= V2 * W^T.
133                                         bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
134                                                 -1, v[k*ldv:], ldv, work, ldwork,
135                                                 1, c[k*ldc:], ldc)
136                                 }
137                                 // W *= V1^T.
138                                 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
139                                         1, v, ldv,
140                                         work, ldwork)
141                                 // C1 -= W^T.
142                                 // TODO(btracey): This should use blas.Axpy.
143                                 for i := 0; i < n; i++ {
144                                         for j := 0; j < k; j++ {
145                                                 c[j*ldc+i] -= work[i*ldwork+j]
146                                         }
147                                 }
148                                 return
149                         }
150                         // Form C = C * H or C * H^T, where C = (C1 C2).
151
152                         // W = C1.
153                         for i := 0; i < k; i++ {
154                                 bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
155                         }
156                         // W *= V1.
157                         bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
158                                 1, v, ldv,
159                                 work, ldwork)
160                         if n > k {
161                                 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
162                                         1, c[k:], ldc, v[k*ldv:], ldv,
163                                         1, work, ldwork)
164                         }
165                         // W *= T or T^T.
166                         bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
167                                 1, t, ldt,
168                                 work, ldwork)
169                         if n > k {
170                                 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
171                                         -1, work, ldwork, v[k*ldv:], ldv,
172                                         1, c[k:], ldc)
173                         }
174                         // C -= W * V^T.
175                         bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
176                                 1, v, ldv,
177                                 work, ldwork)
178                         // C -= W.
179                         // TODO(btracey): This should use blas.Axpy.
180                         for i := 0; i < m; i++ {
181                                 for j := 0; j < k; j++ {
182                                         c[i*ldc+j] -= work[i*ldwork+j]
183                                 }
184                         }
185                         return
186                 }
187                 // V = (V1)
188                 //   = (V2) (last k rows)
189                 // Where V2 is unit upper triangular.
190                 if side == blas.Left {
191                         // Form H * C or
192                         // W = C^T V.
193
194                         // W = C2^T.
195                         for j := 0; j < k; j++ {
196                                 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
197                         }
198                         // W *= V2.
199                         bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
200                                 1, v[(m-k)*ldv:], ldv,
201                                 work, ldwork)
202                         if m > k {
203                                 // W += C1^T * V1.
204                                 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
205                                         1, c, ldc, v, ldv,
206                                         1, work, ldwork)
207                         }
208                         // W *= T or T^T.
209                         bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
210                                 1, t, ldt,
211                                 work, ldwork)
212                         // C -= V * W^T.
213                         if m > k {
214                                 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
215                                         -1, v, ldv, work, ldwork,
216                                         1, c, ldc)
217                         }
218                         // W *= V2^T.
219                         bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
220                                 1, v[(m-k)*ldv:], ldv,
221                                 work, ldwork)
222                         // C2 -= W^T.
223                         // TODO(btracey): This should use blas.Axpy.
224                         for i := 0; i < n; i++ {
225                                 for j := 0; j < k; j++ {
226                                         c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
227                                 }
228                         }
229                         return
230                 }
231                 // Form C * H or C * H^T where C = (C1 C2).
232                 // W = C * V.
233
234                 // W = C2.
235                 for j := 0; j < k; j++ {
236                         bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
237                 }
238
239                 // W = W * V2.
240                 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
241                         1, v[(n-k)*ldv:], ldv,
242                         work, ldwork)
243                 if n > k {
244                         bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
245                                 1, c, ldc, v, ldv,
246                                 1, work, ldwork)
247                 }
248                 // W *= T or T^T.
249                 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
250                         1, t, ldt,
251                         work, ldwork)
252                 // C -= W * V^T.
253                 if n > k {
254                         // C1 -= W * V1^T.
255                         bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
256                                 -1, work, ldwork, v, ldv,
257                                 1, c, ldc)
258                 }
259                 // W *= V2^T.
260                 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
261                         1, v[(n-k)*ldv:], ldv,
262                         work, ldwork)
263                 // C2 -= W.
264                 // TODO(btracey): This should use blas.Axpy.
265                 for i := 0; i < m; i++ {
266                         for j := 0; j < k; j++ {
267                                 c[i*ldc+n-k+j] -= work[i*ldwork+j]
268                         }
269                 }
270                 return
271         }
272         // Store = Rowwise.
273         if direct == lapack.Forward {
274                 // V = (V1 V2) where v1 is unit upper triangular.
275                 if side == blas.Left {
276                         // Form H * C or H^T * C where C = (C1; C2).
277                         // W = C^T * V^T.
278
279                         // W = C1^T.
280                         for j := 0; j < k; j++ {
281                                 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
282                         }
283                         // W *= V1^T.
284                         bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
285                                 1, v, ldv,
286                                 work, ldwork)
287                         if m > k {
288                                 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
289                                         1, c[k*ldc:], ldc, v[k:], ldv,
290                                         1, work, ldwork)
291                         }
292                         // W *= T or T^T.
293                         bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
294                                 1, t, ldt,
295                                 work, ldwork)
296                         // C -= V^T * W^T.
297                         if m > k {
298                                 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
299                                         -1, v[k:], ldv, work, ldwork,
300                                         1, c[k*ldc:], ldc)
301                         }
302                         // W *= V1.
303                         bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
304                                 1, v, ldv,
305                                 work, ldwork)
306                         // C1 -= W^T.
307                         // TODO(btracey): This should use blas.Axpy.
308                         for i := 0; i < n; i++ {
309                                 for j := 0; j < k; j++ {
310                                         c[j*ldc+i] -= work[i*ldwork+j]
311                                 }
312                         }
313                         return
314                 }
315                 // Form C * H or C * H^T where C = (C1 C2).
316                 // W = C * V^T.
317
318                 // W = C1.
319                 for j := 0; j < k; j++ {
320                         bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
321                 }
322                 // W *= V1^T.
323                 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
324                         1, v, ldv,
325                         work, ldwork)
326                 if n > k {
327                         bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
328                                 1, c[k:], ldc, v[k:], ldv,
329                                 1, work, ldwork)
330                 }
331                 // W *= T or T^T.
332                 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
333                         1, t, ldt,
334                         work, ldwork)
335                 // C -= W * V.
336                 if n > k {
337                         bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
338                                 -1, work, ldwork, v[k:], ldv,
339                                 1, c[k:], ldc)
340                 }
341                 // W *= V1.
342                 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
343                         1, v, ldv,
344                         work, ldwork)
345                 // C1 -= W.
346                 // TODO(btracey): This should use blas.Axpy.
347                 for i := 0; i < m; i++ {
348                         for j := 0; j < k; j++ {
349                                 c[i*ldc+j] -= work[i*ldwork+j]
350                         }
351                 }
352                 return
353         }
354         // V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
355         if side == blas.Left {
356                 // Form H * C or H^T C where C = (C1 ; C2).
357                 // W = C^T * V^T.
358
359                 // W = C2^T.
360                 for j := 0; j < k; j++ {
361                         bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
362                 }
363                 // W *= V2^T.
364                 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
365                         1, v[m-k:], ldv,
366                         work, ldwork)
367                 if m > k {
368                         bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
369                                 1, c, ldc, v, ldv,
370                                 1, work, ldwork)
371                 }
372                 // W *= T or T^T.
373                 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
374                         1, t, ldt,
375                         work, ldwork)
376                 // C -= V^T * W^T.
377                 if m > k {
378                         bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
379                                 -1, v, ldv, work, ldwork,
380                                 1, c, ldc)
381                 }
382                 // W *= V2.
383                 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
384                         1, v[m-k:], ldv,
385                         work, ldwork)
386                 // C2 -= W^T.
387                 // TODO(btracey): This should use blas.Axpy.
388                 for i := 0; i < n; i++ {
389                         for j := 0; j < k; j++ {
390                                 c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
391                         }
392                 }
393                 return
394         }
395         // Form C * H or C * H^T where C = (C1 C2).
396         // W = C * V^T.
397         // W = C2.
398         for j := 0; j < k; j++ {
399                 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
400         }
401         // W *= V2^T.
402         bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
403                 1, v[n-k:], ldv,
404                 work, ldwork)
405         if n > k {
406                 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
407                         1, c, ldc, v, ldv,
408                         1, work, ldwork)
409         }
410         // W *= T or T^T.
411         bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
412                 1, t, ldt,
413                 work, ldwork)
414         // C -= W * V.
415         if n > k {
416                 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
417                         -1, work, ldwork, v, ldv,
418                         1, c, ldc)
419         }
420         // W *= V2.
421         bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
422                 1, v[n-k:], ldv,
423                 work, ldwork)
424         // C1 -= W.
425         // TODO(btracey): This should use blas.Axpy.
426         for i := 0; i < m; i++ {
427                 for j := 0; j < k; j++ {
428                         c[i*ldc+n-k+j] -= work[i*ldwork+j]
429                 }
430         }
431 }