// Code generated by "go generate gonum.org/v1/gonum/blas/gonum"; DO NOT EDIT.

// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"runtime"
	"sync"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/internal/asm/f32"
)

// Sgemm computes
//  C = beta * C + alpha * A * B,
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed.
//
// Float32 implementations are autogenerated and not directly tested.
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
		panic(badTranspose)
	}
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if aTrans {
		checkSMatrix('a', k, m, a, lda)
	} else {
		checkSMatrix('a', m, k, a, lda)
	}
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if bTrans {
		checkSMatrix('b', n, k, b, ldb)
	} else {
		checkSMatrix('b', k, n, b, ldb)
	}
	checkSMatrix('c', m, n, c, ldc)

	// scale c
	if beta != 1 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
		}
	}

	sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}
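// exampleSgemmUsage is an illustrative sketch, not part of the generated
// kernels; the function name and the literal data below are assumptions made
// for the example. It shows a plain row-major call computing C = A*B for 2×2
// matrices with beta = 0, so any prior contents of c are overwritten.
func exampleSgemmUsage() {
	var impl Implementation
	a := []float32{
		1, 2,
		3, 4,
	} // 2×2, lda = 2
	b := []float32{
		5, 6,
		7, 8,
	} // 2×2, ldb = 2
	c := make([]float32, 4) // 2×2, ldc = 2
	impl.Sgemm(blas.NoTrans, blas.NoTrans, 2, 2, 2, 1, a, 2, b, 2, 0, c, 2)
	// c now holds {19, 22, 43, 50}, the rows of A*B.
}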
func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// sgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the sub-blocks.
	// In all cases,
	//  A = [ A_11 A_12 ... A_1j
	//        A_21 A_22 ... A_2j
	//         ...
	//        A_i1 A_i2 ... A_ij]
	//
	// and similarly for B. All of the submatrix sizes are blockSize×blockSize except
	// at the edges.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	//  Cij = \sum_k Aik Bkj,  (A * B)
	//  Cij = \sum_k Aki Bkj,  (A^T * B)
	//  Cij = \sum_k Aik Bjk,  (A * B^T)
	//  Cij = \sum_k Aki Bjk,  (A^T * B^T)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This partitioning
	// allows Cij to be updated in place without race conditions.
	// Instead of launching a goroutine for each possible concurrent computation,
	// a number of worker goroutines are created and channels are used to pass
	// available and completed cases.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// cache misses.

	maxKLen := k
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	nWorkers := runtime.GOMAXPROCS(0)
	if parBlocks < nWorkers {
		nWorkers = parBlocks
	}
	// There is a tradeoff between the workers having to wait for work
	// and a large buffer making operations slow.
	buf := buffMul * nWorkers
	if buf > parBlocks {
		buf = parBlocks
	}

	sendChan := make(chan subMul, buf)

	// Launch workers. A worker receives an {i, j} submatrix of c, and computes
	// A_ik B_kj (or the transposed version) storing the result in c_ij. When the
	// channel is finally closed, it signals to the waitgroup that it has finished
	// computing.
	var wg sync.WaitGroup
	for i := 0; i < nWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// Make local copies of otherwise global variables to reduce shared memory.
			// This has a noticeable effect on benchmarks in some cases.
			alpha := alpha
			aTrans := aTrans
			bTrans := bTrans
			m := m
			n := n
			for sub := range sendChan {
				i := sub.i
				j := sub.j
				leni := blockSize
				if i+leni > m {
					leni = m - i
				}
				lenj := blockSize
				if j+lenj > n {
					lenj = n - j
				}
				cSub := sliceView32(c, ldc, i, j, leni, lenj)

				// Compute A_ik B_kj for all k.
				for k := 0; k < maxKLen; k += blockSize {
					lenk := blockSize
					if k+lenk > maxKLen {
						lenk = maxKLen - k
					}
					var aSub, bSub []float32
					if aTrans {
						aSub = sliceView32(a, lda, k, i, lenk, leni)
					} else {
						aSub = sliceView32(a, lda, i, k, leni, lenk)
					}
					if bTrans {
						bSub = sliceView32(b, ldb, j, k, lenj, lenk)
					} else {
						bSub = sliceView32(b, ldb, k, j, lenk, lenj)
					}
					sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
				}
			}
		}()
	}

	// Send out all of the {i, j} subblocks for computation.
	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
			sendChan <- subMul{
				i: i,
				j: j,
			}
		}
	}
	close(sendChan)
	wg.Wait()
}

// sgemmSerial is a serial matrix multiply that dispatches to the kernel
// matching the transpose flags.
func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	switch {
	case !aTrans && !bTrans:
		sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && !bTrans:
		sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case !aTrans && bTrans:
		sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	case aTrans && bTrans:
		sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	default:
		panic("unreachable")
	}
}
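// exampleRowUpdate is an illustrative sketch, not part of the generated
// kernels; the function name and literal data are assumptions. It shows the
// axpy-style row update that the Axpy-based kernels below are built from:
// each element a[i][l] scales row l of b and is accumulated into row i of c,
// so c[i][:] += alpha*a[i][l] * b[l][:].
func exampleRowUpdate() {
	alpha := float32(2)
	arow := []float32{1, 3}          // row i of a, k = 2
	b := []float32{1, 2, 3, 4, 5, 6} // 2×3, ldb = 3
	crow := make([]float32, 3)       // row i of c, n = 3
	for l, v := range arow {
		// crow = (alpha*v)*b[l][:] + crow
		f32.AxpyUnitaryTo(crow, alpha*v, b[l*3:l*3+3], crow)
	}
	// crow now holds {26, 34, 42}, i.e. alpha * (arow · B) for that row.
}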
// sgemmSerialNotNot is the serial kernel where neither a nor b is transposed.
func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are accessed through slices instead of the literal index
	// [i*stride+j] because slicing is approximately 5 times faster as of Go 1.3.
	for i := 0; i < m; i++ {
		ctmp := c[i*ldc : i*ldc+n]
		for l, v := range a[i*lda : i*lda+k] {
			tmp := alpha * v
			if tmp != 0 {
				f32.AxpyUnitaryTo(ctmp, tmp, b[l*ldb:l*ldb+n], ctmp)
			}
		}
	}
}

// sgemmSerialTransNot is the serial kernel where a is transposed and b is not.
func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are accessed through slices instead of the literal index
	// [i*stride+j] because slicing is approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		btmp := b[l*ldb : l*ldb+n]
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
				f32.AxpyUnitaryTo(ctmp, tmp, btmp, ctmp)
			}
		}
	}
}

// sgemmSerialNotTrans is the serial kernel where a is not transposed and b is.
func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are accessed through slices instead of the literal index
	// [i*stride+j] because slicing is approximately 5 times faster as of Go 1.3.
	for i := 0; i < m; i++ {
		atmp := a[i*lda : i*lda+k]
		ctmp := c[i*ldc : i*ldc+n]
		for j := 0; j < n; j++ {
			ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k])
		}
	}
}

// sgemmSerialTransTrans is the serial kernel where both a and b are transposed.
func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// Rows are accessed through slices instead of the literal index
	// [i*stride+j] because slicing is approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			if tmp != 0 {
				ctmp := c[i*ldc : i*ldc+n]
				f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
			}
		}
	}
}

// sliceView32 returns the view of the r×c submatrix of a whose top-left
// element is (i, j). The view shares the leading dimension lda with a.
func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
	return a[i*lda+j : (i+r-1)*lda+j+c]
}
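// exampleSliceView32 is an illustrative sketch, not part of the generated
// kernels; the function name and literal data are assumptions. It shows that
// the view returned by sliceView32 shares memory with the parent matrix and
// keeps the parent's leading dimension, which is why callers above pass lda,
// ldb, and ldc unchanged along with the sub-slices to sgemmSerial.
func exampleSliceView32() {
	// 3×4 row-major matrix with lda = 4.
	m := []float32{
		0, 1, 2, 3,
		4, 5, 6, 7,
		8, 9, 10, 11,
	}
	v := sliceView32(m, 4, 1, 1, 2, 2) // 2×2 view anchored at row 1, col 1
	_ = v[0]     // element (1, 1) of m: 5
	_ = v[1*4+1] // element (2, 2) of m: 10, indexed with the parent stride lda = 4
}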