OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / gonum / sgemm.go
1 // Code generated by "go generate gonum.org/v1/gonum/blas/gonum”; DO NOT EDIT.
2
3 // Copyright ©2014 The Gonum Authors. All rights reserved.
4 // Use of this source code is governed by a BSD-style
5 // license that can be found in the LICENSE file.
6
7 package gonum
8
9 import (
10         "runtime"
11         "sync"
12
13         "gonum.org/v1/gonum/blas"
14         "gonum.org/v1/gonum/internal/asm/f32"
15 )
16
17 // Sgemm computes
18 //  C = beta * C + alpha * A * B,
19 // where A, B, and C are dense matrices, and alpha and beta are scalars.
20 // tA and tB specify whether A or B are transposed.
21 //
22 // Float32 implementations are autogenerated and not directly tested.
23 func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
24         if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
25                 panic(badTranspose)
26         }
27         if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
28                 panic(badTranspose)
29         }
30         aTrans := tA == blas.Trans || tA == blas.ConjTrans
31         if aTrans {
32                 checkSMatrix('a', k, m, a, lda)
33         } else {
34                 checkSMatrix('a', m, k, a, lda)
35         }
36         bTrans := tB == blas.Trans || tB == blas.ConjTrans
37         if bTrans {
38                 checkSMatrix('b', n, k, b, ldb)
39         } else {
40                 checkSMatrix('b', k, n, b, ldb)
41         }
42         checkSMatrix('c', m, n, c, ldc)
43
44         // scale c
45         if beta != 1 {
46                 if beta == 0 {
47                         for i := 0; i < m; i++ {
48                                 ctmp := c[i*ldc : i*ldc+n]
49                                 for j := range ctmp {
50                                         ctmp[j] = 0
51                                 }
52                         }
53                 } else {
54                         for i := 0; i < m; i++ {
55                                 ctmp := c[i*ldc : i*ldc+n]
56                                 for j := range ctmp {
57                                         ctmp[j] *= beta
58                                 }
59                         }
60                 }
61         }
62
63         sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
64 }
65
66 func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
67         // dgemmParallel computes a parallel matrix multiplication by partitioning
68         // a and b into sub-blocks, and updating c with the multiplication of the sub-block
69         // In all cases,
70         // A = [        A_11    A_12 ...        A_1j
71         //                      A_21    A_22 ...        A_2j
72         //                              ...
73         //                      A_i1    A_i2 ...        A_ij]
74         //
75         // and same for B. All of the submatrix sizes are blockSize×blockSize except
76         // at the edges.
77         //
78         // In all cases, there is one dimension for each matrix along which
79         // C must be updated sequentially.
80         // Cij = \sum_k Aik Bki,        (A * B)
81         // Cij = \sum_k Aki Bkj,        (A^T * B)
82         // Cij = \sum_k Aik Bjk,        (A * B^T)
83         // Cij = \sum_k Aki Bjk,        (A^T * B^T)
84         //
85         // This code computes one {i, j} block sequentially along the k dimension,
86         // and computes all of the {i, j} blocks concurrently. This
87         // partitioning allows Cij to be updated in-place without race-conditions.
88         // Instead of launching a goroutine for each possible concurrent computation,
89         // a number of worker goroutines are created and channels are used to pass
90         // available and completed cases.
91         //
92         // http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
93         // multiplies, though this code does not copy matrices to attempt to eliminate
94         // cache misses.
95
96         maxKLen := k
97         parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
98         if parBlocks < minParBlock {
99                 // The matrix multiplication is small in the dimensions where it can be
100                 // computed concurrently. Just do it in serial.
101                 sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
102                 return
103         }
104
105         nWorkers := runtime.GOMAXPROCS(0)
106         if parBlocks < nWorkers {
107                 nWorkers = parBlocks
108         }
109         // There is a tradeoff between the workers having to wait for work
110         // and a large buffer making operations slow.
111         buf := buffMul * nWorkers
112         if buf > parBlocks {
113                 buf = parBlocks
114         }
115
116         sendChan := make(chan subMul, buf)
117
118         // Launch workers. A worker receives an {i, j} submatrix of c, and computes
119         // A_ik B_ki (or the transposed version) storing the result in c_ij. When the
120         // channel is finally closed, it signals to the waitgroup that it has finished
121         // computing.
122         var wg sync.WaitGroup
123         for i := 0; i < nWorkers; i++ {
124                 wg.Add(1)
125                 go func() {
126                         defer wg.Done()
127                         // Make local copies of otherwise global variables to reduce shared memory.
128                         // This has a noticeable effect on benchmarks in some cases.
129                         alpha := alpha
130                         aTrans := aTrans
131                         bTrans := bTrans
132                         m := m
133                         n := n
134                         for sub := range sendChan {
135                                 i := sub.i
136                                 j := sub.j
137                                 leni := blockSize
138                                 if i+leni > m {
139                                         leni = m - i
140                                 }
141                                 lenj := blockSize
142                                 if j+lenj > n {
143                                         lenj = n - j
144                                 }
145
146                                 cSub := sliceView32(c, ldc, i, j, leni, lenj)
147
148                                 // Compute A_ik B_kj for all k
149                                 for k := 0; k < maxKLen; k += blockSize {
150                                         lenk := blockSize
151                                         if k+lenk > maxKLen {
152                                                 lenk = maxKLen - k
153                                         }
154                                         var aSub, bSub []float32
155                                         if aTrans {
156                                                 aSub = sliceView32(a, lda, k, i, lenk, leni)
157                                         } else {
158                                                 aSub = sliceView32(a, lda, i, k, leni, lenk)
159                                         }
160                                         if bTrans {
161                                                 bSub = sliceView32(b, ldb, j, k, lenj, lenk)
162                                         } else {
163                                                 bSub = sliceView32(b, ldb, k, j, lenk, lenj)
164                                         }
165                                         sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
166                                 }
167                         }
168                 }()
169         }
170
171         // Send out all of the {i, j} subblocks for computation.
172         for i := 0; i < m; i += blockSize {
173                 for j := 0; j < n; j += blockSize {
174                         sendChan <- subMul{
175                                 i: i,
176                                 j: j,
177                         }
178                 }
179         }
180         close(sendChan)
181         wg.Wait()
182 }
183
184 // sgemmSerial is serial matrix multiply
185 func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
186         switch {
187         case !aTrans && !bTrans:
188                 sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
189                 return
190         case aTrans && !bTrans:
191                 sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
192                 return
193         case !aTrans && bTrans:
194                 sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
195                 return
196         case aTrans && bTrans:
197                 sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
198                 return
199         default:
200                 panic("unreachable")
201         }
202 }
203
204 // sgemmSerial where neither a nor b are transposed
205 func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
206         // This style is used instead of the literal [i*stride +j]) is used because
207         // approximately 5 times faster as of go 1.3.
208         for i := 0; i < m; i++ {
209                 ctmp := c[i*ldc : i*ldc+n]
210                 for l, v := range a[i*lda : i*lda+k] {
211                         tmp := alpha * v
212                         if tmp != 0 {
213                                 f32.AxpyUnitaryTo(ctmp, tmp, b[l*ldb:l*ldb+n], ctmp)
214                         }
215                 }
216         }
217 }
218
219 // sgemmSerial where neither a is transposed and b is not
220 func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
221         // This style is used instead of the literal [i*stride +j]) is used because
222         // approximately 5 times faster as of go 1.3.
223         for l := 0; l < k; l++ {
224                 btmp := b[l*ldb : l*ldb+n]
225                 for i, v := range a[l*lda : l*lda+m] {
226                         tmp := alpha * v
227                         if tmp != 0 {
228                                 ctmp := c[i*ldc : i*ldc+n]
229                                 f32.AxpyUnitaryTo(ctmp, tmp, btmp, ctmp)
230                         }
231                 }
232         }
233 }
234
235 // sgemmSerial where neither a is not transposed and b is
236 func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
237         // This style is used instead of the literal [i*stride +j]) is used because
238         // approximately 5 times faster as of go 1.3.
239         for i := 0; i < m; i++ {
240                 atmp := a[i*lda : i*lda+k]
241                 ctmp := c[i*ldc : i*ldc+n]
242                 for j := 0; j < n; j++ {
243                         ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k])
244                 }
245         }
246 }
247
248 // sgemmSerial where both are transposed
249 func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
250         // This style is used instead of the literal [i*stride +j]) is used because
251         // approximately 5 times faster as of go 1.3.
252         for l := 0; l < k; l++ {
253                 for i, v := range a[l*lda : l*lda+m] {
254                         tmp := alpha * v
255                         if tmp != 0 {
256                                 ctmp := c[i*ldc : i*ldc+n]
257                                 f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
258                         }
259                 }
260         }
261 }
262
263 func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
264         return a[i*lda+j : (i+r-1)*lda+j+c]
265 }