// Code generated by "go generate gonum.org/v1/gonum/blas/gonum"; DO NOT EDIT.

// Copyright ©2014 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package gonum

import (
	"runtime"
	"sync"

	"gonum.org/v1/gonum/blas"
	"gonum.org/v1/gonum/internal/asm/f32"
)
// Sgemm performs the matrix-matrix operation
//
//	C = beta * C + alpha * A * B,
//
// where A, B, and C are dense matrices, and alpha and beta are scalars.
// tA and tB specify whether A or B are transposed.
//
// Float32 implementations are autogenerated and not directly tested.
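//
// A minimal usage sketch (the sizes and values below are illustrative
// assumptions, not taken from this file):
//
//	var impl Implementation
//	a := []float32{1, 2, 3, 4, 5, 6} // 2×3 matrix A, row-major, lda = 3
//	b := []float32{1, 0, 0, 1, 1, 1} // 3×2 matrix B, row-major, ldb = 2
//	c := make([]float32, 2*2)        // 2×2 result C, ldc = 2
//	impl.Sgemm(blas.NoTrans, blas.NoTrans, 2, 2, 3, 1, a, 3, b, 2, 0, c, 2)
//	// c = 1*A*B + 0*C = [4 5; 10 11]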
func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
	if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
		panic(badTranspose)
	}
	if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
		panic(badTranspose)
	}
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if aTrans {
		checkSMatrix('a', k, m, a, lda)
	} else {
		checkSMatrix('a', m, k, a, lda)
	}
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if bTrans {
		checkSMatrix('b', n, k, b, ldb)
	} else {
		checkSMatrix('b', k, n, b, ldb)
	}
	checkSMatrix('c', m, n, c, ldc)
	// Scale c by beta.
	if beta != 1 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
		}
	}
	sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}
func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// sgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks and updating c with the products of those
	// sub-blocks. In all cases,
	//  A = [ A_11 A_12 ... A_1j
	//        ...
	//        A_i1 A_i2 ... A_ij]
	//
	// and same for B. All of the submatrix sizes are blockSize×blockSize except
	// for the blocks in the final row and column, which may be smaller.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	//  Cij = \sum_k Aik Bkj,  (A * B)
	//  Cij = \sum_k Aki Bkj,  (A^T * B)
	//  Cij = \sum_k Aik Bjk,  (A * B^T)
	//  Cij = \sum_k Aki Bjk,  (A^T * B^T)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This
	// partitioning allows Cij to be updated in-place without race conditions.
	// Instead of launching a goroutine for each possible concurrent computation,
	// a number of worker goroutines are created and channels are used to pass
	// available and completed cases.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// caching effects.
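	//
	// For example (sizes here are illustrative only, assuming blockSize = 64
	// and that blocks performs a ceiling division): an m = 200, n = 150
	// multiplication yields blocks(200, 64) * blocks(150, 64) = 4 * 3 = 12
	// independent {i, j} blocks of C, each of which accumulates its sum over
	// k serially below.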
	maxKLen := k
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	nWorkers := runtime.GOMAXPROCS(0)
	if parBlocks < nWorkers {
		nWorkers = parBlocks
	}
	// There is a tradeoff between the workers having to wait for work
	// and a large buffer making operations slow.
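	// For example, if buffMul were 4 and GOMAXPROCS reported 8, up to 32
	// sub-block jobs could be queued in the channel awaiting a free worker.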
	buf := buffMul * nWorkers

	sendChan := make(chan subMul, buf)
	// Launch workers. A worker receives an {i, j} submatrix of c, and computes
	// A_ik B_kj (or the transposed version), storing the result in c_ij. When the
	// channel is finally closed, it signals to the waitgroup that it has finished
	// computing.
	var wg sync.WaitGroup
	for i := 0; i < nWorkers; i++ {
		// Make local copies of otherwise global variables to reduce shared memory.
		// This has a noticeable effect on benchmarks in some cases.
		for sub := range sendChan {
			cSub := sliceView32(c, ldc, i, j, leni, lenj)

			// Compute A_ik B_kj for all k
			for k := 0; k < maxKLen; k += blockSize {
				lenk := blockSize
				if k+lenk > maxKLen {
					lenk = maxKLen - k
				}
				var aSub, bSub []float32
				if aTrans {
					aSub = sliceView32(a, lda, k, i, lenk, leni)
				} else {
					aSub = sliceView32(a, lda, i, k, leni, lenk)
				}
				if bTrans {
					bSub = sliceView32(b, ldb, j, k, lenj, lenk)
				} else {
					bSub = sliceView32(b, ldb, k, j, lenk, lenj)
				}
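				// sgemmSerial accumulates alpha * aSub * bSub into cSub. The views
				// returned by sliceView32 share their backing arrays with a, b, and c,
				// so the original strides lda, ldb, and ldc still apply to them.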
				sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
	// Send out all of the {i, j} subblocks for computation.
	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
// sgemmSerial is a serial matrix multiply.
func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	switch {
	case !aTrans && !bTrans:
		sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
	case aTrans && !bTrans:
		sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
	case !aTrans && bTrans:
		sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
	case aTrans && bTrans:
		sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
	}
}
// sgemmSerialNotNot is sgemmSerial where neither a nor b is transposed.
func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// This style is used instead of the literal [i*stride+j] indexing because
	// it is approximately 5 times faster as of Go 1.3.
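	// Written with literal indexing, the inner update below would read
	//	for j := 0; j < n; j++ {
	//		c[i*ldc+j] += alpha * a[i*lda+l] * b[l*ldb+j]
	//	}
	// slicing each row once avoids recomputing those index expressions.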
	for i := 0; i < m; i++ {
		ctmp := c[i*ldc : i*ldc+n]
		for l, v := range a[i*lda : i*lda+k] {
			tmp := alpha * v
			f32.AxpyUnitaryTo(ctmp, tmp, b[l*ldb:l*ldb+n], ctmp)
		}
	}
}
// sgemmSerialTransNot is sgemmSerial where a is transposed and b is not.
func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// This style is used instead of the literal [i*stride+j] indexing because
	// it is approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		btmp := b[l*ldb : l*ldb+n]
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			ctmp := c[i*ldc : i*ldc+n]
			f32.AxpyUnitaryTo(ctmp, tmp, btmp, ctmp)
		}
	}
}
// sgemmSerialNotTrans is sgemmSerial where a is not transposed and b is.
func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// This style is used instead of the literal [i*stride+j] indexing because
	// it is approximately 5 times faster as of Go 1.3.
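	// b is stored n×k and used transposed here, so each element c[i][j] is the
	// dot product of row i of a with row j of b.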
	for i := 0; i < m; i++ {
		atmp := a[i*lda : i*lda+k]
		ctmp := c[i*ldc : i*ldc+n]
		for j := 0; j < n; j++ {
			ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k])
		}
	}
}
// sgemmSerialTransTrans is sgemmSerial where both a and b are transposed.
func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
	// This style is used instead of the literal [i*stride+j] indexing because
	// it is approximately 5 times faster as of Go 1.3.
	for l := 0; l < k; l++ {
		for i, v := range a[l*lda : l*lda+m] {
			tmp := alpha * v
			ctmp := c[i*ldc : i*ldc+n]
			f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
		}
	}
}
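// sliceView32 returns a view of the r×c submatrix of the lda-strided matrix a
// whose top-left element is at row i, column j. The view shares its backing
// array with a, so writes through the view modify a.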
func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
	return a[i*lda+j : (i+r-1)*lda+j+c]
}