OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / product.go
diff --git a/vendor/gonum.org/v1/gonum/mat/product.go b/vendor/gonum.org/v1/gonum/mat/product.go
new file mode 100644 (file)
index 0000000..08424ab
--- /dev/null
@@ -0,0 +1,193 @@
+// Copyright ©2015 The Gonum Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package mat
+
+import "fmt"
+
+// Product calculates the product of the given factors and places the result in
+// the receiver. The order of multiplication operations is optimized to minimize
+// the number of floating point operations on the basis that all matrix
+// multiplications are general.
+func (m *Dense) Product(factors ...Matrix) {
+       // The operation order optimisation is the naive O(n^3) dynamic
+       // programming approach and does not take into consideration
+       // finer-grained optimisations that might be available.
+       //
+       // TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
+       // algorithms that are available. e.g.
+       //
+       // e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
+       //
+       // In the case that this is replaced, retain this code in
+       // tests to compare against.
+
+       r, c := m.Dims()
+       switch len(factors) {
+       case 0:
+               if r != 0 || c != 0 {
+                       panic(ErrShape)
+               }
+               return
+       case 1:
+               m.reuseAs(factors[0].Dims())
+               m.Copy(factors[0])
+               return
+       case 2:
+               // Don't do work that we know the answer to.
+               m.Mul(factors[0], factors[1])
+               return
+       }
+
+       p := newMultiplier(m, factors)
+       p.optimize()
+       result := p.multiply()
+       m.reuseAs(result.Dims())
+       m.Copy(result)
+       putWorkspace(result)
+}
+
+// debugProductWalk enables debugging output for Product.
+const debugProductWalk = false
+
+// multiplier performs operation order optimisation and tree traversal.
+type multiplier struct {
+       // factors is the ordered set of
+       // factors to multiply.
+       factors []Matrix
+       // dims is the chain of factor
+       // dimensions.
+       dims []int
+
+       // table contains the dynamic
+       // programming costs and subchain
+       // division indices.
+       table table
+}
+
+func newMultiplier(m *Dense, factors []Matrix) *multiplier {
+       // Check size early, but don't yet
+       // allocate data for m.
+       r, c := m.Dims()
+       fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2.
+       if !m.IsZero() {
+               if fr != r {
+                       panic(ErrShape)
+               }
+               if _, lc := factors[len(factors)-1].Dims(); lc != c {
+                       panic(ErrShape)
+               }
+       }
+
+       dims := make([]int, len(factors)+1)
+       dims[0] = r
+       dims[len(dims)-1] = c
+       pc := fc
+       for i, f := range factors[1:] {
+               cr, cc := f.Dims()
+               dims[i+1] = cr
+               if pc != cr {
+                       panic(ErrShape)
+               }
+               pc = cc
+       }
+
+       return &multiplier{
+               factors: factors,
+               dims:    dims,
+               table:   newTable(len(factors)),
+       }
+}
+
+// optimize determines an optimal matrix multiply operation order.
+func (p *multiplier) optimize() {
+       if debugProductWalk {
+               fmt.Printf("chain dims: %v\n", p.dims)
+       }
+       const maxInt = int(^uint(0) >> 1)
+       for f := 1; f < len(p.factors); f++ {
+               for i := 0; i < len(p.factors)-f; i++ {
+                       j := i + f
+                       p.table.set(i, j, entry{cost: maxInt})
+                       for k := i; k < j; k++ {
+                               cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1]
+                               if cost < p.table.at(i, j).cost {
+                                       p.table.set(i, j, entry{cost: cost, k: k})
+                               }
+                       }
+               }
+       }
+}
+
+// multiply walks the optimal operation tree found by optimize,
+// leaving the final result in the stack. It returns the
+// product, which may be copied but should be returned to
+// the workspace pool.
+func (p *multiplier) multiply() *Dense {
+       result, _ := p.multiplySubchain(0, len(p.factors)-1)
+       if debugProductWalk {
+               r, c := result.Dims()
+               fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost)
+       }
+       return result.(*Dense)
+}
+
+func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
+       if i == j {
+               return p.factors[i], false
+       }
+
+       a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k)
+       b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j)
+
+       ar, ac := a.Dims()
+       br, bc := b.Dims()
+       if ac != br {
+               // Panic with a string since this
+               // is not a user-facing panic.
+               panic(ErrShape.Error())
+       }
+
+       if debugProductWalk {
+               fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
+                       i, ar, ac, result(aTmp), j, br, bc, result(bTmp))
+       }
+
+       r := getWorkspace(ar, bc, false)
+       r.Mul(a, b)
+       if aTmp {
+               putWorkspace(a.(*Dense))
+       }
+       if bTmp {
+               putWorkspace(b.(*Dense))
+       }
+       return r, true
+}
+
+type entry struct {
+       k    int // is the chain subdivision index.
+       cost int // cost is the cost of the operation.
+}
+
+// table is a row major n×n dynamic programming table.
+type table struct {
+       n       int
+       entries []entry
+}
+
+func newTable(n int) table {
+       return table{n: n, entries: make([]entry, n*n)}
+}
+
+func (t table) at(i, j int) entry     { return t.entries[i*t.n+j] }
+func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e }
+
+type result bool
+
+func (r result) String() string {
+       if r {
+               return " (popped result)"
+       }
+       return ""
+}