1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
9 // Product calculates the product of the given factors and places the result in
10 // the receiver. The order of multiplication operations is optimized to minimize
11 // the number of floating point operations on the basis that all matrix
12 // multiplications are general.
13 func (m *Dense) Product(factors ...Matrix) {
14 // The operation order optimisation is the naive O(n^3) dynamic
15 // programming approach and does not take into consideration
16 // finer-grained optimisations that might be available.
18 // TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
19 // algorithms that are available, for example:
21 // http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
23 // In the case that this is replaced, retain this code in
24 // tests to compare against.
34 m.reuseAs(factors[0].Dims())
38 // Don't do work that we know the answer to.
39 m.Mul(factors[0], factors[1])
43 p := newMultiplier(m, factors)
45 result := p.multiply()
46 m.reuseAs(result.Dims())
// debugProductWalk is a compile-time switch. When true, the operation-order
// search and the multiplication tree walk used by Product print trace lines
// with fmt.Printf. It is off by default.
const (
	debugProductWalk = false
)
54 // multiplier performs operation order optimisation and tree traversal.
55 type multiplier struct {
56 // factors is the ordered set of
57 // factors to multiply.
59 // dims is the chain of factor
63 // table contains the dynamic
64 // programming costs and subchain
69 func newMultiplier(m *Dense, factors []Matrix) *multiplier {
70 // Check size early, but don't yet
71 // allocate data for m.
73 fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2.
78 if _, lc := factors[len(factors)-1].Dims(); lc != c {
83 dims := make([]int, len(factors)+1)
87 for i, f := range factors[1:] {
99 table: newTable(len(factors)),
103 // optimize determines an optimal matrix multiply operation order.
104 func (p *multiplier) optimize() {
105 if debugProductWalk {
106 fmt.Printf("chain dims: %v\n", p.dims)
108 const maxInt = int(^uint(0) >> 1)
109 for f := 1; f < len(p.factors); f++ {
110 for i := 0; i < len(p.factors)-f; i++ {
112 p.table.set(i, j, entry{cost: maxInt})
113 for k := i; k < j; k++ {
114 cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1]
115 if cost < p.table.at(i, j).cost {
116 p.table.set(i, j, entry{cost: cost, k: k})
123 // multiply walks the optimal operation tree found by optimize,
124 // leaving the final result in the stack. It returns the
125 // product, which may be copied but should be returned to
126 // the workspace pool.
127 func (p *multiplier) multiply() *Dense {
128 result, _ := p.multiplySubchain(0, len(p.factors)-1)
129 if debugProductWalk {
130 r, c := result.Dims()
131 fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost)
133 return result.(*Dense)
136 func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
138 return p.factors[i], false
141 a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k)
142 b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j)
147 // Panic with a string since this
148 // is not a user-facing panic.
149 panic(ErrShape.Error())
152 if debugProductWalk {
153 fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
154 i, ar, ac, result(aTmp), j, br, bc, result(bTmp))
157 r := getWorkspace(ar, bc, false)
160 putWorkspace(a.(*Dense))
163 putWorkspace(b.(*Dense))
169 k int // is the chain subdivision index.
170 cost int // cost is the cost of the operation.
173 // table is a row major n×n dynamic programming table.
179 func newTable(n int) table {
180 return table{n: n, entries: make([]entry, n*n)}
183 func (t table) at(i, j int) entry { return t.entries[i*t.n+j] }
184 func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e }
188 func (r result) String() string {
190 return " (popped result)"