1 // Copyright ©2014 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
11 "gonum.org/v1/gonum/blas"
12 "gonum.org/v1/gonum/internal/asm/f64"
// Dgemm performs the matrix-matrix operation
//  C = beta * C + alpha * A * B,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed.
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
	// Validate the transpose flags before touching any data.
	// NOTE(review): the panic statements and closing braces of these checks
	// are elided in this excerpt — confirm against the full file.
	if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
	if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
	// For real matrices Trans and ConjTrans are equivalent, so both map to
	// a single "transposed" flag.
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	// Dimension checks: a transposed A is k×m, otherwise A is m×k.
		checkDMatrix('a', k, m, a, lda)
		checkDMatrix('a', m, k, a, lda)
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	// A transposed B is n×k, otherwise B is k×n.
		checkDMatrix('b', n, k, b, ldb)
		checkDMatrix('b', k, n, b, ldb)
	// C is always m×n regardless of the transpose flags.
	checkDMatrix('c', m, n, c, ldc)
	// Scale C by beta row-by-row before the product is accumulated.
	// NOTE(review): the two loop bodies are elided here — presumably the
	// first zeroes ctmp when beta == 0 and the second scales by beta; verify
	// against the complete source.
	for i := 0; i < m; i++ {
		ctmp := c[i*ldc : i*ldc+n]
	for i := 0; i < m; i++ {
		ctmp := c[i*ldc : i*ldc+n]
	// Accumulate alpha * op(A) * op(B) into C via the (possibly parallel)
	// blocked multiply.
	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
// dgemmParallel accumulates c += alpha * op(a) * op(b) using a blocked,
// worker-pool decomposition over the {i, j} tiles of c.
func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// dgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the sub-block
	// A = [ A_11 A_12 ... A_1j
	//       A_i1 A_i2 ... A_ij]
	// and same for B. All of the submatrix sizes are blockSize×blockSize except
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	// Cij = \sum_k Aik Bkj, (A * B)
	// Cij = \sum_k Aki Bkj, (A^T * B)
	// Cij = \sum_k Aik Bjk, (A * B^T)
	// Cij = \sum_k Aki Bjk, (A^T * B^T)
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This
	// partitioning allows Cij to be updated in-place without race-conditions.
	// Instead of launching a goroutine for each possible concurrent computation,
	// a number of worker goroutines are created and channels are used to pass
	// available and completed cases.
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// parBlocks counts the {i, j} tiles of c, the unit of concurrent work.
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		// NOTE(review): the early return after this call is elided in this
		// excerpt.
		dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
	// Never run more workers than there are tiles or logical CPUs.
	nWorkers := runtime.GOMAXPROCS(0)
	if parBlocks < nWorkers {
	// There is a tradeoff between the workers having to wait for work
	// and a large buffer making operations slow.
	buf := buffMul * nWorkers
	// sendChan distributes {i, j} tile coordinates to the workers.
	sendChan := make(chan subMul, buf)
	// Launch workers. A worker receives an {i, j} submatrix of c, and computes
	// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
	// channel is finally closed, it signals to the waitgroup that it has finished
	// NOTE(review): the `go func()` launch, wg.Add/Done bookkeeping and the
	// leni/lenj edge-tile clamping are elided in this excerpt.
	var wg sync.WaitGroup
	for i := 0; i < nWorkers; i++ {
		// Make local copies of otherwise global variables to reduce shared memory.
		// This has a noticeable effect on benchmarks in some cases.
		for sub := range sendChan {
			// cSub is the {i, j} tile of c this worker owns exclusively,
			// so it can be updated without synchronization.
			cSub := sliceView64(c, ldc, i, j, leni, lenj)
			// Compute A_ik B_kj for all k
			for k := 0; k < maxKLen; k += blockSize {
				// Clamp the final, possibly ragged, k tile.
				if k+lenk > maxKLen {
				var aSub, bSub []float64
				// Transposed operands swap the roles of the row and
				// column coordinates when slicing out the tile.
					aSub = sliceView64(a, lda, k, i, lenk, leni)
					aSub = sliceView64(a, lda, i, k, leni, lenk)
					bSub = sliceView64(b, ldb, j, k, lenj, lenk)
					bSub = sliceView64(b, ldb, k, j, lenk, lenj)
				// Tiles keep the parent strides, so the serial kernel is
				// called with the original lda/ldb/ldc.
				dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
	// Send out all of the {i, j} subblocks for computation.
	// NOTE(review): the channel send, close(sendChan) and wg.Wait() are
	// elided in this excerpt.
	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
180 // dgemmSerial is serial matrix multiply
181 func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
183 case !aTrans && !bTrans:
184 dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
186 case aTrans && !bTrans:
187 dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
189 case !aTrans && bTrans:
190 dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
192 case aTrans && bTrans:
193 dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
200 // dgemmSerial where neither a nor b are transposed
201 func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
202 // This style is used instead of the literal [i*stride +j]) is used because
203 // approximately 5 times faster as of go 1.3.
204 for i := 0; i < m; i++ {
205 ctmp := c[i*ldc : i*ldc+n]
206 for l, v := range a[i*lda : i*lda+k] {
209 f64.AxpyUnitaryTo(ctmp, tmp, b[l*ldb:l*ldb+n], ctmp)
215 // dgemmSerial where neither a is transposed and b is not
216 func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
217 // This style is used instead of the literal [i*stride +j]) is used because
218 // approximately 5 times faster as of go 1.3.
219 for l := 0; l < k; l++ {
220 btmp := b[l*ldb : l*ldb+n]
221 for i, v := range a[l*lda : l*lda+m] {
224 ctmp := c[i*ldc : i*ldc+n]
225 f64.AxpyUnitaryTo(ctmp, tmp, btmp, ctmp)
231 // dgemmSerial where neither a is not transposed and b is
232 func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
233 // This style is used instead of the literal [i*stride +j]) is used because
234 // approximately 5 times faster as of go 1.3.
235 for i := 0; i < m; i++ {
236 atmp := a[i*lda : i*lda+k]
237 ctmp := c[i*ldc : i*ldc+n]
238 for j := 0; j < n; j++ {
239 ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
244 // dgemmSerial where both are transposed
245 func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
246 // This style is used instead of the literal [i*stride +j]) is used because
247 // approximately 5 times faster as of go 1.3.
248 for l := 0; l < k; l++ {
249 for i, v := range a[l*lda : l*lda+m] {
252 ctmp := c[i*ldc : i*ldc+n]
253 f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
// sliceView64 returns a view into the r×c submatrix of the lda-strided
// matrix a whose top-left element lies at row i, column j. The returned
// slice aliases a's backing array and runs from the submatrix's first
// element through its last, so it supports strided access with stride lda.
func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
	start := i*lda + j
	end := start + (r-1)*lda + c
	return a[start:end]
}