OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / product.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import "fmt"
8
9 // Product calculates the product of the given factors and places the result in
10 // the receiver. The order of multiplication operations is optimized to minimize
11 // the number of floating point operations on the basis that all matrix
12 // multiplications are general.
13 func (m *Dense) Product(factors ...Matrix) {
14         // The operation order optimisation is the naive O(n^3) dynamic
15         // programming approach and does not take into consideration
16         // finer-grained optimisations that might be available.
17         //
18         // TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
19         // algorithms that are available. e.g.
20         //
21         // e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
22         //
23         // In the case that this is replaced, retain this code in
24         // tests to compare against.
25
26         r, c := m.Dims()
27         switch len(factors) {
28         case 0:
29                 if r != 0 || c != 0 {
30                         panic(ErrShape)
31                 }
32                 return
33         case 1:
34                 m.reuseAs(factors[0].Dims())
35                 m.Copy(factors[0])
36                 return
37         case 2:
38                 // Don't do work that we know the answer to.
39                 m.Mul(factors[0], factors[1])
40                 return
41         }
42
43         p := newMultiplier(m, factors)
44         p.optimize()
45         result := p.multiply()
46         m.reuseAs(result.Dims())
47         m.Copy(result)
48         putWorkspace(result)
49 }
50
51 // debugProductWalk enables debugging output for Product.
52 const debugProductWalk = false
53
54 // multiplier performs operation order optimisation and tree traversal.
55 type multiplier struct {
56         // factors is the ordered set of
57         // factors to multiply.
58         factors []Matrix
59         // dims is the chain of factor
60         // dimensions.
61         dims []int
62
63         // table contains the dynamic
64         // programming costs and subchain
65         // division indices.
66         table table
67 }
68
69 func newMultiplier(m *Dense, factors []Matrix) *multiplier {
70         // Check size early, but don't yet
71         // allocate data for m.
72         r, c := m.Dims()
73         fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2.
74         if !m.IsZero() {
75                 if fr != r {
76                         panic(ErrShape)
77                 }
78                 if _, lc := factors[len(factors)-1].Dims(); lc != c {
79                         panic(ErrShape)
80                 }
81         }
82
83         dims := make([]int, len(factors)+1)
84         dims[0] = r
85         dims[len(dims)-1] = c
86         pc := fc
87         for i, f := range factors[1:] {
88                 cr, cc := f.Dims()
89                 dims[i+1] = cr
90                 if pc != cr {
91                         panic(ErrShape)
92                 }
93                 pc = cc
94         }
95
96         return &multiplier{
97                 factors: factors,
98                 dims:    dims,
99                 table:   newTable(len(factors)),
100         }
101 }
102
103 // optimize determines an optimal matrix multiply operation order.
104 func (p *multiplier) optimize() {
105         if debugProductWalk {
106                 fmt.Printf("chain dims: %v\n", p.dims)
107         }
108         const maxInt = int(^uint(0) >> 1)
109         for f := 1; f < len(p.factors); f++ {
110                 for i := 0; i < len(p.factors)-f; i++ {
111                         j := i + f
112                         p.table.set(i, j, entry{cost: maxInt})
113                         for k := i; k < j; k++ {
114                                 cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1]
115                                 if cost < p.table.at(i, j).cost {
116                                         p.table.set(i, j, entry{cost: cost, k: k})
117                                 }
118                         }
119                 }
120         }
121 }
122
123 // multiply walks the optimal operation tree found by optimize,
124 // leaving the final result in the stack. It returns the
125 // product, which may be copied but should be returned to
126 // the workspace pool.
127 func (p *multiplier) multiply() *Dense {
128         result, _ := p.multiplySubchain(0, len(p.factors)-1)
129         if debugProductWalk {
130                 r, c := result.Dims()
131                 fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost)
132         }
133         return result.(*Dense)
134 }
135
136 func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
137         if i == j {
138                 return p.factors[i], false
139         }
140
141         a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k)
142         b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j)
143
144         ar, ac := a.Dims()
145         br, bc := b.Dims()
146         if ac != br {
147                 // Panic with a string since this
148                 // is not a user-facing panic.
149                 panic(ErrShape.Error())
150         }
151
152         if debugProductWalk {
153                 fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
154                         i, ar, ac, result(aTmp), j, br, bc, result(bTmp))
155         }
156
157         r := getWorkspace(ar, bc, false)
158         r.Mul(a, b)
159         if aTmp {
160                 putWorkspace(a.(*Dense))
161         }
162         if bTmp {
163                 putWorkspace(b.(*Dense))
164         }
165         return r, true
166 }
167
168 type entry struct {
169         k    int // is the chain subdivision index.
170         cost int // cost is the cost of the operation.
171 }
172
173 // table is a row major n×n dynamic programming table.
174 type table struct {
175         n       int
176         entries []entry
177 }
178
179 func newTable(n int) table {
180         return table{n: n, entries: make([]entry, n*n)}
181 }
182
183 func (t table) at(i, j int) entry     { return t.entries[i*t.n+j] }
184 func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e }
185
186 type result bool
187
188 func (r result) String() string {
189         if r {
190                 return " (popped result)"
191         }
192         return ""
193 }