1 // Copyright ©2014 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
8 "gonum.org/v1/gonum/blas"
9 "gonum.org/v1/gonum/internal/asm/f64"
// Compile-time check that an Implementation value is assignable to
// blas.Float64Level3; the blank identifier means no value is retained.
12 var _ blas.Float64Level3 = Implementation{}
15 // A * X = alpha * B, if tA == blas.NoTrans and side == blas.Left,
16 // A^T * X = alpha * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Left,
17 // X * A = alpha * B, if tA == blas.NoTrans and side == blas.Right,
18 // X * A^T = alpha * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Right,
19 // where A is an n×n or m×m triangular matrix, X is an m×n matrix, and alpha is a
22 // At entry to the function, X contains the values of B, and the result is
23 // stored in place into X.
25 // No check is made that A is invertible.
26 func (Implementation) Dtrsm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
// Each enumeration argument must be one of the values named in its check.
27 if s != blas.Left && s != blas.Right {
30 if ul != blas.Lower && ul != blas.Upper {
33 if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
36 if d != blas.NonUnit && d != blas.Unit {
// a is addressed with row stride lda and b with row stride ldb; both slices
// must be long enough for the last element addressed that way.
// NOTE(review): k is assigned on lines not shown here; presumably the order
// of A (m or n depending on s) — confirm.
54 if lda*(k-1)+k > len(a) || lda < max(1, k) {
57 if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
66 for i := 0; i < m; i++ {
67 btmp := b[i*ldb : i*ldb+n]
// Remember whether the diagonal was flagged blas.NonUnit.
74 nonUnit := d == blas.NonUnit
76 if tA == blas.NoTrans {
// Rows of b are visited from the last to the first.
78 for i := m - 1; i >= 0; i-- {
79 btmp := b[i*ldb : i*ldb+n]
// Walk the entries of row i of a to the right of the diagonal and
// subtract va times row k of b from row i of b.
85 for ka, va := range a[i*lda+i+1 : i*lda+m] {
88 f64.AxpyUnitaryTo(btmp, -va, b[k*ldb:k*ldb+n], btmp)
93 for j := 0; j < n; j++ {
// Rows of b are visited from the first to the last.
100 for i := 0; i < m; i++ {
101 btmp := b[i*ldb : i*ldb+n]
103 for j := 0; j < n; j++ {
// Walk the entries of row i of a to the left of the diagonal.
107 for k, va := range a[i*lda : i*lda+i] {
109 f64.AxpyUnitaryTo(btmp, -va, b[k*ldb:k*ldb+n], btmp)
// Reciprocal of the diagonal entry; a zero diagonal is not rejected, so
// the result is Inf/NaN rather than a panic.
113 tmp := 1 / a[i*lda+i]
114 for j := 0; j < n; j++ {
121 // Cases where a is transposed
122 if ul == blas.Upper {
123 for k := 0; k < m; k++ {
124 btmpk := b[k*ldb : k*ldb+n]
126 tmp := 1 / a[k*lda+k]
127 for j := 0; j < n; j++ {
// Subtract va times row k of b from another row i of b.
131 for ia, va := range a[k*lda+k+1 : k*lda+m] {
134 btmp := b[i*ldb : i*ldb+n]
135 f64.AxpyUnitaryTo(btmp, -va, btmpk, btmp)
139 for j := 0; j < n; j++ {
146 for k := m - 1; k >= 0; k-- {
147 btmpk := b[k*ldb : k*ldb+n]
149 tmp := 1 / a[k*lda+k]
150 for j := 0; j < n; j++ {
154 for i, va := range a[k*lda : k*lda+k] {
156 btmp := b[i*ldb : i*ldb+n]
157 f64.AxpyUnitaryTo(btmp, -va, btmpk, btmp)
161 for j := 0; j < n; j++ {
168 // Cases where a is to the right of X.
169 if tA == blas.NoTrans {
170 if ul == blas.Upper {
171 for i := 0; i < m; i++ {
172 btmp := b[i*ldb : i*ldb+n]
174 for j := 0; j < n; j++ {
178 for k, vb := range btmp {
182 btmp[k] /= a[k*lda+k]
// Update the part of the row to the right of column k using the
// entries of row k of a to the right of the diagonal.
184 btmpk := btmp[k+1 : n]
185 f64.AxpyUnitaryTo(btmpk, -btmp[k], a[k*lda+k+1:k*lda+n], btmpk)
192 for i := 0; i < m; i++ {
// Row i of b starts at i*ldb. Slicing b with lda is only correct when
// lda == ldb; otherwise it addresses the wrong elements or panics.
193 btmp := b[i*ldb : i*ldb+n]
195 for j := 0; j < n; j++ {
199 for k := n - 1; k >= 0; k-- {
202 btmp[k] /= a[k*lda+k]
204 f64.AxpyUnitaryTo(btmp, -btmp[k], a[k*lda:k*lda+k], btmp)
210 // Cases where a is transposed.
211 if ul == blas.Upper {
212 for i := 0; i < m; i++ {
// b is strided by ldb, not lda.
213 btmp := b[i*ldb : i*ldb+n]
214 for j := n - 1; j >= 0; j-- {
// alpha*b[i][j] minus the dot product of row j of a right of the
// diagonal with the entries of this row of b right of column j.
215 tmp := alpha*btmp[j] - f64.DotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:])
224 for i := 0; i < m; i++ {
// b is strided by ldb, not lda.
225 btmp := b[i*ldb : i*ldb+n]
226 for j := 0; j < n; j++ {
227 tmp := alpha*btmp[j] - f64.DotUnitary(a[j*lda:j*lda+j], btmp)
236 // Dsymm performs one of
237 // C = alpha * A * B + beta * C, if side == blas.Left,
238 // C = alpha * B * A + beta * C, if side == blas.Right,
239 // where A is an n×n or m×m symmetric matrix, B and C are m×n matrices, and alpha
241 func (Implementation) Dsymm(s blas.Side, ul blas.Uplo, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
// s must be blas.Left or blas.Right; anything else panics.
242 if s != blas.Right && s != blas.Left {
243 panic("goblas: bad side")
245 if ul != blas.Lower && ul != blas.Upper {
// a, b and c are addressed with row strides lda, ldb and ldc; each slice must
// be long enough for the last element addressed that way.
// NOTE(review): k is assigned on lines not shown here; presumably the order
// of A (m or n depending on s) — confirm.
260 if lda*(k-1)+k > len(a) || lda < max(1, k) {
263 if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
266 if ldc*(m-1)+n > len(c) || ldc < max(1, n) {
// Special-cased inputs: an empty C, and alpha == 0 with beta == 1.
// NOTE(review): the bodies of these branches are not visible here.
269 if m == 0 || n == 0 {
272 if alpha == 0 && beta == 1 {
// Two passes over the rows of c, one row slice at a time.
277 for i := 0; i < m; i++ {
278 ctmp := c[i*ldc : i*ldc+n]
279 for j := range ctmp {
285 for i := 0; i < m; i++ {
286 ctmp := c[i*ldc : i*ldc+n]
287 for j := 0; j < n; j++ {
294 isUpper := ul == blas.Upper
296 for i := 0; i < m; i++ {
// Diagonal entry of a scaled by alpha, paired with row i of b and c.
297 atmp := alpha * a[i*lda+i]
298 btmp := b[i*ldb : i*ldb+n]
299 ctmp := c[i*ldc : i*ldc+n]
300 for j, v := range btmp {
// Rows of b above row i contribute to row i of c.
305 for k := 0; k < i; k++ {
313 ctmp := c[i*ldc : i*ldc+n]
314 f64.AxpyUnitaryTo(ctmp, atmp, b[k*ldb:k*ldb+n], ctmp)
// Rows of b below row i contribute to row i of c.
316 for k := i + 1; k < m; k++ {
324 ctmp := c[i*ldc : i*ldc+n]
325 f64.AxpyUnitaryTo(ctmp, atmp, b[k*ldb:k*ldb+n], ctmp)
331 for i := 0; i < m; i++ {
// Columns are visited from last to first.
332 for j := n - 1; j >= 0; j-- {
333 tmp := alpha * b[i*ldb+j]
// Row j of a, and row i of b and c, restricted to the columns right of j.
335 atmp := a[j*lda+j+1 : j*lda+n]
336 btmp := b[i*ldb+j+1 : i*ldb+n]
337 ctmp := c[i*ldc+j+1 : i*ldc+n]
338 for k, v := range atmp {
// Add the diagonal term and alpha times the accumulated tmp2.
343 c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
348 for i := 0; i < m; i++ {
349 for j := 0; j < n; j++ {
350 tmp := alpha * b[i*ldb+j]
// Row j of a, and row i of b and c, restricted to the columns left of j.
352 atmp := a[j*lda : j*lda+j]
353 btmp := b[i*ldb : i*ldb+j]
354 ctmp := c[i*ldc : i*ldc+j]
355 for k, v := range atmp {
360 c[i*ldc+j] += tmp*a[j*lda+j] + alpha*tmp2
365 // Dsyrk performs the symmetric rank-k operation
366 // C = alpha * A * A^T + beta*C
// when tA == blas.NoTrans; the transposed branches accumulate
// alpha * A^T * A into C instead.
367 // C is an n×n symmetric matrix. A is an n×k matrix if tA == blas.NoTrans, and
368 // a k×n matrix otherwise. alpha and beta are scalars.
369 func (Implementation) Dsyrk(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, beta float64, c []float64, ldc int) {
370 if ul != blas.Lower && ul != blas.Upper {
373 if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
// NOTE(review): row and col are assigned on lines not shown here; presumably
// the dimensions of A as stored, chosen by tA — confirm.
386 if tA == blas.NoTrans {
391 if lda*(row-1)+col > len(a) || lda < max(1, col) {
394 if ldc*(n-1)+n > len(c) || ldc < max(1, n) {
// Two groups of loops over the ul triangle of c only: from the diagonal to
// the end of the row for Upper, from the start of the row through the
// diagonal otherwise.
399 if ul == blas.Upper {
400 for i := 0; i < n; i++ {
401 ctmp := c[i*ldc+i : i*ldc+n]
402 for j := range ctmp {
408 for i := 0; i < n; i++ {
409 ctmp := c[i*ldc : i*ldc+i+1]
410 for j := range ctmp {
416 if ul == blas.Upper {
417 for i := 0; i < n; i++ {
418 ctmp := c[i*ldc+i : i*ldc+n]
419 for j := range ctmp {
425 for i := 0; i < n; i++ {
426 ctmp := c[i*ldc : i*ldc+i+1]
427 for j := range ctmp {
433 if tA == blas.NoTrans {
434 if ul == blas.Upper {
435 for i := 0; i < n; i++ {
436 ctmp := c[i*ldc+i : i*ldc+n]
437 atmp := a[i*lda : i*lda+k]
438 for jc, vc := range ctmp {
// Scale the old entry by beta and add alpha times the dot product of
// rows i and j of a.
440 ctmp[jc] = vc*beta + alpha*f64.DotUnitary(atmp, a[j*lda:j*lda+k])
445 for i := 0; i < n; i++ {
446 atmp := a[i*lda : i*lda+k]
447 for j, vc := range c[i*ldc : i*ldc+i+1] {
448 c[i*ldc+j] = vc*beta + alpha*f64.DotUnitary(a[j*lda:j*lda+k], atmp)
453 // Cases where a is transposed.
454 if ul == blas.Upper {
455 for i := 0; i < n; i++ {
456 ctmp := c[i*ldc+i : i*ldc+n]
458 for j := range ctmp {
// For every row l of a, add alpha*a[l][i] times columns i..n-1 of that row
// into the upper part of row i of c.
462 for l := 0; l < k; l++ {
463 tmp := alpha * a[l*lda+i]
465 f64.AxpyUnitaryTo(ctmp, tmp, a[l*lda+i:l*lda+n], ctmp)
471 for i := 0; i < n; i++ {
472 ctmp := c[i*ldc : i*ldc+i+1]
474 for j := range ctmp {
// Same accumulation using columns 0..i of row l of a.
478 for l := 0; l < k; l++ {
479 tmp := alpha * a[l*lda+i]
481 f64.AxpyUnitaryTo(ctmp, tmp, a[l*lda:l*lda+i+1], ctmp)
487 // Dsyr2k performs the symmetric rank 2k operation
488 // C = alpha * A * B^T + alpha * B * A^T + beta * C
// when tA == blas.NoTrans; the transposed branches accumulate
// alpha * A^T * B + alpha * B^T * A into C instead.
489 // where C is an n×n symmetric matrix. A and B are n×k matrices if
490 // tA == blas.NoTrans and k×n otherwise. alpha and beta are scalars.
491 func (Implementation) Dsyr2k(ul blas.Uplo, tA blas.Transpose, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
492 if ul != blas.Lower && ul != blas.Upper {
495 if tA != blas.Trans && tA != blas.NoTrans && tA != blas.ConjTrans {
// NOTE(review): row and col are assigned on lines not shown here; presumably
// the dimensions of A and B as stored, chosen by tA — confirm.
508 if tA == blas.NoTrans {
513 if lda*(row-1)+col > len(a) || lda < max(1, col) {
516 if ldb*(row-1)+col > len(b) || ldb < max(1, col) {
519 if ldc*(n-1)+n > len(c) || ldc < max(1, n) {
// Two groups of loops over the ul triangle of c only.
524 if ul == blas.Upper {
525 for i := 0; i < n; i++ {
526 ctmp := c[i*ldc+i : i*ldc+n]
527 for j := range ctmp {
533 for i := 0; i < n; i++ {
534 ctmp := c[i*ldc : i*ldc+i+1]
535 for j := range ctmp {
541 if ul == blas.Upper {
542 for i := 0; i < n; i++ {
543 ctmp := c[i*ldc+i : i*ldc+n]
544 for j := range ctmp {
550 for i := 0; i < n; i++ {
551 ctmp := c[i*ldc : i*ldc+i+1]
552 for j := range ctmp {
558 if tA == blas.NoTrans {
559 if ul == blas.Upper {
560 for i := 0; i < n; i++ {
561 atmp := a[i*lda : i*lda+k]
562 btmp := b[i*ldb : i*ldb+k]
563 ctmp := c[i*ldc+i : i*ldc+n]
564 for jc := range ctmp {
// tmp1 and tmp2 accumulate the two dot products that make up this
// entry's update.
566 var tmp1, tmp2 float64
567 binner := b[j*ldb : j*ldb+k]
568 for l, v := range a[j*lda : j*lda+k] {
570 tmp2 += atmp[l] * binner[l]
573 ctmp[jc] += alpha * (tmp1 + tmp2)
578 for i := 0; i < n; i++ {
579 atmp := a[i*lda : i*lda+k]
580 btmp := b[i*ldb : i*ldb+k]
581 ctmp := c[i*ldc : i*ldc+i+1]
582 for j := 0; j <= i; j++ {
583 var tmp1, tmp2 float64
584 binner := b[j*ldb : j*ldb+k]
585 for l, v := range a[j*lda : j*lda+k] {
587 tmp2 += atmp[l] * binner[l]
590 ctmp[j] += alpha * (tmp1 + tmp2)
595 if ul == blas.Upper {
596 for i := 0; i < n; i++ {
597 ctmp := c[i*ldc+i : i*ldc+n]
599 for j := range ctmp {
603 for l := 0; l < k; l++ {
// b is strided by ldb and a by lda. Indexing b with lda is only correct
// when lda == ldb; otherwise it reads the wrong element or panics.
604 tmp1 := alpha * b[l*ldb+i]
605 tmp2 := alpha * a[l*lda+i]
606 btmp := b[l*ldb+i : l*ldb+n]
// Skip the inner loop when both scaled factors are zero.
607 if tmp1 != 0 || tmp2 != 0 {
608 for j, v := range a[l*lda+i : l*lda+n] {
609 ctmp[j] += v*tmp1 + btmp[j]*tmp2
616 for i := 0; i < n; i++ {
617 ctmp := c[i*ldc : i*ldc+i+1]
619 for j := range ctmp {
623 for l := 0; l < k; l++ {
// b is strided by ldb, not lda.
624 tmp1 := alpha * b[l*ldb+i]
625 tmp2 := alpha * a[l*lda+i]
626 btmp := b[l*ldb : l*ldb+i+1]
627 if tmp1 != 0 || tmp2 != 0 {
628 for j, v := range a[l*lda : l*lda+i+1] {
629 ctmp[j] += v*tmp1 + btmp[j]*tmp2
637 // B = alpha * A * B, if tA == blas.NoTrans and side == blas.Left,
638 // B = alpha * A^T * B, if tA == blas.Trans or blas.ConjTrans, and side == blas.Left,
639 // B = alpha * B * A, if tA == blas.NoTrans and side == blas.Right,
640 // B = alpha * B * A^T, if tA == blas.Trans or blas.ConjTrans, and side == blas.Right,
641 // where A is an n×n or m×m triangular matrix, and B is an m×n matrix.
642 func (Implementation) Dtrmm(s blas.Side, ul blas.Uplo, tA blas.Transpose, d blas.Diag, m, n int, alpha float64, a []float64, lda int, b []float64, ldb int) {
// Each enumeration argument must be one of the values named in its check.
643 if s != blas.Left && s != blas.Right {
646 if ul != blas.Lower && ul != blas.Upper {
649 if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
652 if d != blas.NonUnit && d != blas.Unit {
// a is addressed with row stride lda and b with row stride ldb; both slices
// must be long enough for the last element addressed that way.
// NOTE(review): k is assigned on lines not shown here; presumably the order
// of A (m or n depending on s) — confirm.
667 if lda*(k-1)+k > len(a) || lda < max(1, k) {
670 if ldb*(m-1)+n > len(b) || ldb < max(1, n) {
674 for i := 0; i < m; i++ {
675 btmp := b[i*ldb : i*ldb+n]
676 for j := range btmp {
// Remember whether the diagonal was flagged blas.NonUnit.
683 nonUnit := d == blas.NonUnit
685 if tA == blas.NoTrans {
686 if ul == blas.Upper {
// Rows of b are visited from the first to the last.
687 for i := 0; i < m; i++ {
692 btmp := b[i*ldb : i*ldb+n]
693 for j := range btmp {
// Walk the entries of row i of a to the right of the diagonal and
// add tmp times row k of b into row i of b.
696 for ka, va := range a[i*lda+i+1 : i*lda+m] {
700 f64.AxpyUnitaryTo(btmp, tmp, b[k*ldb:k*ldb+n], btmp)
// Rows of b are visited from the last to the first.
706 for i := m - 1; i >= 0; i-- {
711 btmp := b[i*ldb : i*ldb+n]
712 for j := range btmp {
// Walk the entries of row i of a to the left of the diagonal.
715 for k, va := range a[i*lda : i*lda+i] {
718 f64.AxpyUnitaryTo(btmp, tmp, b[k*ldb:k*ldb+n], btmp)
724 // Cases where a is transposed.
725 if ul == blas.Upper {
726 for k := m - 1; k >= 0; k-- {
727 btmpk := b[k*ldb : k*ldb+n]
// Add tmp times row k of b into another row i of b.
728 for ia, va := range a[k*lda+k+1 : k*lda+m] {
730 btmp := b[i*ldb : i*ldb+n]
733 f64.AxpyUnitaryTo(btmp, tmp, btmpk, btmp)
741 for j := 0; j < n; j++ {
748 for k := 0; k < m; k++ {
749 btmpk := b[k*ldb : k*ldb+n]
750 for i, va := range a[k*lda : k*lda+k] {
751 btmp := b[i*ldb : i*ldb+n]
754 f64.AxpyUnitaryTo(btmp, tmp, btmpk, btmp)
762 for j := 0; j < n; j++ {
769 // Cases where a is on the right
770 if tA == blas.NoTrans {
771 if ul == blas.Upper {
772 for i := 0; i < m; i++ {
773 btmp := b[i*ldb : i*ldb+n]
// Columns are visited from last to first.
774 for k := n - 1; k >= 0; k-- {
775 tmp := alpha * btmp[k]
779 btmp[k] *= a[k*lda+k]
// Walk the entries of row k of a to the right of the diagonal.
781 for ja, v := range a[k*lda+k+1 : k*lda+n] {
790 for i := 0; i < m; i++ {
791 btmp := b[i*ldb : i*ldb+n]
// Columns are visited from first to last.
792 for k := 0; k < n; k++ {
793 tmp := alpha * btmp[k]
797 btmp[k] *= a[k*lda+k]
799 f64.AxpyUnitaryTo(btmp, tmp, a[k*lda:k*lda+k], btmp)
805 // Cases where a is transposed.
806 if ul == blas.Upper {
807 for i := 0; i < m; i++ {
808 btmp := b[i*ldb : i*ldb+n]
809 for j, vb := range btmp {
// Add the dot product of row j of a right of the diagonal with the
// entries of this row of b right of column j, then scale by alpha.
814 tmp += f64.DotUnitary(a[j*lda+j+1:j*lda+n], btmp[j+1:n])
815 btmp[j] = alpha * tmp
820 for i := 0; i < m; i++ {
821 btmp := b[i*ldb : i*ldb+n]
// Columns are visited from last to first.
822 for j := n - 1; j >= 0; j-- {
827 tmp += f64.DotUnitary(a[j*lda:j*lda+j], btmp[:j])
828 btmp[j] = alpha * tmp