OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / blas / gonum / dgemm.go
1 // Copyright ©2014 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package gonum
6
7 import (
8         "runtime"
9         "sync"
10
11         "gonum.org/v1/gonum/blas"
12         "gonum.org/v1/gonum/internal/asm/f64"
13 )
14
15 // Dgemm computes
16 //  C = beta * C + alpha * A * B,
17 // where A, B, and C are dense matrices, and alpha and beta are scalars.
18 // tA and tB specify whether A or B are transposed.
19 func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
20         if tA != blas.NoTrans && tA != blas.Trans && tA != blas.ConjTrans {
21                 panic(badTranspose)
22         }
23         if tB != blas.NoTrans && tB != blas.Trans && tB != blas.ConjTrans {
24                 panic(badTranspose)
25         }
26         aTrans := tA == blas.Trans || tA == blas.ConjTrans
27         if aTrans {
28                 checkDMatrix('a', k, m, a, lda)
29         } else {
30                 checkDMatrix('a', m, k, a, lda)
31         }
32         bTrans := tB == blas.Trans || tB == blas.ConjTrans
33         if bTrans {
34                 checkDMatrix('b', n, k, b, ldb)
35         } else {
36                 checkDMatrix('b', k, n, b, ldb)
37         }
38         checkDMatrix('c', m, n, c, ldc)
39
40         // scale c
41         if beta != 1 {
42                 if beta == 0 {
43                         for i := 0; i < m; i++ {
44                                 ctmp := c[i*ldc : i*ldc+n]
45                                 for j := range ctmp {
46                                         ctmp[j] = 0
47                                 }
48                         }
49                 } else {
50                         for i := 0; i < m; i++ {
51                                 ctmp := c[i*ldc : i*ldc+n]
52                                 for j := range ctmp {
53                                         ctmp[j] *= beta
54                                 }
55                         }
56                 }
57         }
58
59         dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
60 }
61
62 func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
63         // dgemmParallel computes a parallel matrix multiplication by partitioning
64         // a and b into sub-blocks, and updating c with the multiplication of the sub-block
65         // In all cases,
66         // A = [        A_11    A_12 ...        A_1j
67         //                      A_21    A_22 ...        A_2j
68         //                              ...
69         //                      A_i1    A_i2 ...        A_ij]
70         //
71         // and same for B. All of the submatrix sizes are blockSize×blockSize except
72         // at the edges.
73         //
74         // In all cases, there is one dimension for each matrix along which
75         // C must be updated sequentially.
76         // Cij = \sum_k Aik Bki,        (A * B)
77         // Cij = \sum_k Aki Bkj,        (A^T * B)
78         // Cij = \sum_k Aik Bjk,        (A * B^T)
79         // Cij = \sum_k Aki Bjk,        (A^T * B^T)
80         //
81         // This code computes one {i, j} block sequentially along the k dimension,
82         // and computes all of the {i, j} blocks concurrently. This
83         // partitioning allows Cij to be updated in-place without race-conditions.
84         // Instead of launching a goroutine for each possible concurrent computation,
85         // a number of worker goroutines are created and channels are used to pass
86         // available and completed cases.
87         //
88         // http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
89         // multiplies, though this code does not copy matrices to attempt to eliminate
90         // cache misses.
91
92         maxKLen := k
93         parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
94         if parBlocks < minParBlock {
95                 // The matrix multiplication is small in the dimensions where it can be
96                 // computed concurrently. Just do it in serial.
97                 dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
98                 return
99         }
100
101         nWorkers := runtime.GOMAXPROCS(0)
102         if parBlocks < nWorkers {
103                 nWorkers = parBlocks
104         }
105         // There is a tradeoff between the workers having to wait for work
106         // and a large buffer making operations slow.
107         buf := buffMul * nWorkers
108         if buf > parBlocks {
109                 buf = parBlocks
110         }
111
112         sendChan := make(chan subMul, buf)
113
114         // Launch workers. A worker receives an {i, j} submatrix of c, and computes
115         // A_ik B_ki (or the transposed version) storing the result in c_ij. When the
116         // channel is finally closed, it signals to the waitgroup that it has finished
117         // computing.
118         var wg sync.WaitGroup
119         for i := 0; i < nWorkers; i++ {
120                 wg.Add(1)
121                 go func() {
122                         defer wg.Done()
123                         // Make local copies of otherwise global variables to reduce shared memory.
124                         // This has a noticeable effect on benchmarks in some cases.
125                         alpha := alpha
126                         aTrans := aTrans
127                         bTrans := bTrans
128                         m := m
129                         n := n
130                         for sub := range sendChan {
131                                 i := sub.i
132                                 j := sub.j
133                                 leni := blockSize
134                                 if i+leni > m {
135                                         leni = m - i
136                                 }
137                                 lenj := blockSize
138                                 if j+lenj > n {
139                                         lenj = n - j
140                                 }
141
142                                 cSub := sliceView64(c, ldc, i, j, leni, lenj)
143
144                                 // Compute A_ik B_kj for all k
145                                 for k := 0; k < maxKLen; k += blockSize {
146                                         lenk := blockSize
147                                         if k+lenk > maxKLen {
148                                                 lenk = maxKLen - k
149                                         }
150                                         var aSub, bSub []float64
151                                         if aTrans {
152                                                 aSub = sliceView64(a, lda, k, i, lenk, leni)
153                                         } else {
154                                                 aSub = sliceView64(a, lda, i, k, leni, lenk)
155                                         }
156                                         if bTrans {
157                                                 bSub = sliceView64(b, ldb, j, k, lenj, lenk)
158                                         } else {
159                                                 bSub = sliceView64(b, ldb, k, j, lenk, lenj)
160                                         }
161                                         dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
162                                 }
163                         }
164                 }()
165         }
166
167         // Send out all of the {i, j} subblocks for computation.
168         for i := 0; i < m; i += blockSize {
169                 for j := 0; j < n; j += blockSize {
170                         sendChan <- subMul{
171                                 i: i,
172                                 j: j,
173                         }
174                 }
175         }
176         close(sendChan)
177         wg.Wait()
178 }
179
180 // dgemmSerial is serial matrix multiply
181 func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
182         switch {
183         case !aTrans && !bTrans:
184                 dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
185                 return
186         case aTrans && !bTrans:
187                 dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
188                 return
189         case !aTrans && bTrans:
190                 dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
191                 return
192         case aTrans && bTrans:
193                 dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
194                 return
195         default:
196                 panic("unreachable")
197         }
198 }
199
200 // dgemmSerial where neither a nor b are transposed
201 func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
202         // This style is used instead of the literal [i*stride +j]) is used because
203         // approximately 5 times faster as of go 1.3.
204         for i := 0; i < m; i++ {
205                 ctmp := c[i*ldc : i*ldc+n]
206                 for l, v := range a[i*lda : i*lda+k] {
207                         tmp := alpha * v
208                         if tmp != 0 {
209                                 f64.AxpyUnitaryTo(ctmp, tmp, b[l*ldb:l*ldb+n], ctmp)
210                         }
211                 }
212         }
213 }
214
215 // dgemmSerial where neither a is transposed and b is not
216 func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
217         // This style is used instead of the literal [i*stride +j]) is used because
218         // approximately 5 times faster as of go 1.3.
219         for l := 0; l < k; l++ {
220                 btmp := b[l*ldb : l*ldb+n]
221                 for i, v := range a[l*lda : l*lda+m] {
222                         tmp := alpha * v
223                         if tmp != 0 {
224                                 ctmp := c[i*ldc : i*ldc+n]
225                                 f64.AxpyUnitaryTo(ctmp, tmp, btmp, ctmp)
226                         }
227                 }
228         }
229 }
230
231 // dgemmSerial where neither a is not transposed and b is
232 func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
233         // This style is used instead of the literal [i*stride +j]) is used because
234         // approximately 5 times faster as of go 1.3.
235         for i := 0; i < m; i++ {
236                 atmp := a[i*lda : i*lda+k]
237                 ctmp := c[i*ldc : i*ldc+n]
238                 for j := 0; j < n; j++ {
239                         ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
240                 }
241         }
242 }
243
244 // dgemmSerial where both are transposed
245 func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
246         // This style is used instead of the literal [i*stride +j]) is used because
247         // approximately 5 times faster as of go 1.3.
248         for l := 0; l < k; l++ {
249                 for i, v := range a[l*lda : l*lda+m] {
250                         tmp := alpha * v
251                         if tmp != 0 {
252                                 ctmp := c[i*ldc : i*ldc+n]
253                                 f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
254                         }
255                 }
256         }
257 }
258
259 func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
260         return a[i*lda+j : (i+r-1)*lda+j+c]
261 }