1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
8 "gonum.org/v1/gonum/blas"
9 "gonum.org/v1/gonum/blas/blas64"
10 "gonum.org/v1/gonum/lapack"
13 // Dlarfb applies a block reflector to a matrix.
15 // In the call to Dlarfb, the m×n matrix c is multiplied by the implicitly defined matrix h as follows:
16 // c = h * c if side == Left and trans == NoTrans
17 // c = c * h if side == Right and trans == NoTrans
18 // c = h^T * c if side == Left and trans == Trans
19 // c = c * h^T if side == Right and trans == Trans
20 // h is a product of elementary reflectors. direct sets the direction of multiplication
21 // h = h_1 * h_2 * ... * h_k if direct == Forward
22 // h = h_k * h_k-1 * ... * h_1 if direct == Backward
23 // The combination of direct and store defines the orientation of the elementary
24 // reflectors. In all cases the ones on the diagonal are implicitly represented.
26 // If direct == lapack.Forward and store == lapack.ColumnWise
32 // If direct == lapack.Forward and store == lapack.RowWise
33 // V = [ 1 v1 v1 v1 v1]
36 // If direct == lapack.Backward and store == lapack.ColumnWise
42 // If direct == lapack.Backward and store == lapack.RowWise
46 // An elementary reflector can be explicitly constructed by extracting the
47 // corresponding elements of v, placing a 1 where the diagonal would be, and
48 // placing zeros in the remaining elements.
50 // t is a k×k matrix containing the block reflector, and this function will panic
51 // if t is not of sufficient size. See Dlarft for more information.
53 // work is a temporary storage matrix with stride ldwork.
54 // work must be of size at least n×k if side == Left and m×k if side == Right, and
55 // this function will panic if this size is not met.
57 // Dlarfb is an internal routine. It is exported for testing purposes.
58 func (Implementation) Dlarfb(side blas.Side, trans blas.Transpose, direct lapack.Direct, store lapack.StoreV, m, n, k int, v []float64, ldv int, t []float64, ldt int, c []float64, ldc int, work []float64, ldwork int) {
59 if side != blas.Left && side != blas.Right {
62 if trans != blas.Trans && trans != blas.NoTrans {
65 if direct != lapack.Forward && direct != lapack.Backward {
68 if store != lapack.ColumnWise && store != lapack.RowWise {
71 checkMatrix(m, n, c, ldc)
75 checkMatrix(k, k, t, ldt)
78 if side == blas.Right {
82 if store == lapack.ColumnWise {
83 checkMatrix(nv, k, v, ldv)
85 checkMatrix(k, nv, v, ldv)
87 checkMatrix(nw, k, work, ldwork)
93 bi := blas64.Implementation()
96 if trans == blas.Trans {
99 // TODO(btracey): This follows the original Lapack code where the
100 // elements are copied into the columns of the working array. The
101 // loops should go in the other direction so the data is written
102 // into the rows of work so the copy is not strided. A bigger change
103 // would be to replace work with work^T, but benchmarks would be
104 // needed to see if the change is merited.
105 if store == lapack.ColumnWise {
106 if direct == lapack.Forward {
107 // V1 is the first k rows of V. V2 is the remaining rows.
108 if side == blas.Left {
109 // W = C^T V = C1^T V1 + C2^T V2 (stored in work).
112 for j := 0; j < k; j++ {
113 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
116 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit,
122 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
123 1, c[k*ldc:], ldc, v[k*ldv:], ldv,
126 // W = W * T^T or W * T.
127 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
133 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
134 -1, v[k*ldv:], ldv, work, ldwork,
138 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
142 // TODO(btracey): This should use blas.Axpy.
143 for i := 0; i < n; i++ {
144 for j := 0; j < k; j++ {
145 c[j*ldc+i] -= work[i*ldwork+j]
150 // Form C = C * H or C * H^T, where C = (C1 C2).
153 for i := 0; i < k; i++ {
154 bi.Dcopy(m, c[i:], ldc, work[i:], ldwork)
157 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
161 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
162 1, c[k:], ldc, v[k*ldv:], ldv,
166 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
170 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
171 -1, work, ldwork, v[k*ldv:], ldv,
175 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
179 // TODO(btracey): This should use blas.Axpy.
180 for i := 0; i < m; i++ {
181 for j := 0; j < k; j++ {
182 c[i*ldc+j] -= work[i*ldwork+j]
188 // = (V2) (last k rows)
189 // Where V2 is unit upper triangular.
190 if side == blas.Left {
195 for j := 0; j < k; j++ {
196 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
199 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
200 1, v[(m-k)*ldv:], ldv,
204 bi.Dgemm(blas.Trans, blas.NoTrans, n, k, m-k,
209 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
214 bi.Dgemm(blas.NoTrans, blas.Trans, m-k, n, k,
215 -1, v, ldv, work, ldwork,
219 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
220 1, v[(m-k)*ldv:], ldv,
223 // TODO(btracey): This should use blas.Axpy.
224 for i := 0; i < n; i++ {
225 for j := 0; j < k; j++ {
226 c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
231 // Form C * H or C * H^T where C = (C1 C2).
235 for j := 0; j < k; j++ {
236 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
240 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
241 1, v[(n-k)*ldv:], ldv,
244 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, k, n-k,
249 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
255 bi.Dgemm(blas.NoTrans, blas.Trans, m, n-k, k,
256 -1, work, ldwork, v, ldv,
260 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
261 1, v[(n-k)*ldv:], ldv,
264 // TODO(btracey): This should use blas.Axpy.
265 for i := 0; i < m; i++ {
266 for j := 0; j < k; j++ {
267 c[i*ldc+n-k+j] -= work[i*ldwork+j]
273 if direct == lapack.Forward {
274 // V = (V1 V2) where V1 is unit upper triangular.
275 if side == blas.Left {
276 // Form H * C or H^T * C where C = (C1; C2).
280 for j := 0; j < k; j++ {
281 bi.Dcopy(n, c[j*ldc:], 1, work[j:], ldwork)
284 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, n, k,
288 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
289 1, c[k*ldc:], ldc, v[k:], ldv,
293 bi.Dtrmm(blas.Right, blas.Upper, transt, blas.NonUnit, n, k,
298 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
299 -1, v[k:], ldv, work, ldwork,
303 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, n, k,
307 // TODO(btracey): This should use blas.Axpy.
308 for i := 0; i < n; i++ {
309 for j := 0; j < k; j++ {
310 c[j*ldc+i] -= work[i*ldwork+j]
315 // Form C * H or C * H^T where C = (C1 C2).
319 for j := 0; j < k; j++ {
320 bi.Dcopy(m, c[j:], ldc, work[j:], ldwork)
323 bi.Dtrmm(blas.Right, blas.Upper, blas.Trans, blas.Unit, m, k,
327 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
328 1, c[k:], ldc, v[k:], ldv,
332 bi.Dtrmm(blas.Right, blas.Upper, trans, blas.NonUnit, m, k,
337 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
338 -1, work, ldwork, v[k:], ldv,
342 bi.Dtrmm(blas.Right, blas.Upper, blas.NoTrans, blas.Unit, m, k,
346 // TODO(btracey): This should use blas.Axpy.
347 for i := 0; i < m; i++ {
348 for j := 0; j < k; j++ {
349 c[i*ldc+j] -= work[i*ldwork+j]
354 // V = (V1 V2) where V2 is the last k columns and is lower unit triangular.
355 if side == blas.Left {
356 // Form H * C or H^T * C where C = (C1 ; C2).
360 for j := 0; j < k; j++ {
361 bi.Dcopy(n, c[(m-k+j)*ldc:], 1, work[j:], ldwork)
364 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, n, k,
368 bi.Dgemm(blas.Trans, blas.Trans, n, k, m-k,
373 bi.Dtrmm(blas.Right, blas.Lower, transt, blas.NonUnit, n, k,
378 bi.Dgemm(blas.Trans, blas.Trans, m-k, n, k,
379 -1, v, ldv, work, ldwork,
383 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, n, k,
387 // TODO(btracey): This should use blas.Axpy.
388 for i := 0; i < n; i++ {
389 for j := 0; j < k; j++ {
390 c[(m-k+j)*ldc+i] -= work[i*ldwork+j]
395 // Form C * H or C * H^T where C = (C1 C2).
398 for j := 0; j < k; j++ {
399 bi.Dcopy(m, c[n-k+j:], ldc, work[j:], ldwork)
402 bi.Dtrmm(blas.Right, blas.Lower, blas.Trans, blas.Unit, m, k,
406 bi.Dgemm(blas.NoTrans, blas.Trans, m, k, n-k,
411 bi.Dtrmm(blas.Right, blas.Lower, trans, blas.NonUnit, m, k,
416 bi.Dgemm(blas.NoTrans, blas.NoTrans, m, n-k, k,
417 -1, work, ldwork, v, ldv,
421 bi.Dtrmm(blas.Right, blas.Lower, blas.NoTrans, blas.Unit, m, k,
425 // TODO(btracey): This should use blas.Axpy.
426 for i := 0; i < m; i++ {
427 for j := 0; j < k; j++ {
428 c[i*ldc+n-k+j] -= work[i*ldwork+j]