+++ /dev/null
-// Copyright ©2015 The Gonum Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package mat
-
-import "fmt"
-
-// Product calculates the product of the given factors and places the result in
-// the receiver. The order of multiplication operations is optimized to minimize
-// the number of floating point operations on the basis that all matrix
-// multiplications are general.
-func (m *Dense) Product(factors ...Matrix) {
- // The operation order optimisation is the naive O(n^3) dynamic
- // programming approach and does not take into consideration
- // finer-grained optimisations that might be available.
- //
- // TODO(kortschak) Consider using the O(nlogn) or O(mlogn)
- // algorithms that are available. e.g.
- //
- // e.g. http://www.jofcis.com/publishedpapers/2014_10_10_4299_4306.pdf
- //
- // In the case that this is replaced, retain this code in
- // tests to compare against.
-
- r, c := m.Dims()
- switch len(factors) {
- case 0:
- if r != 0 || c != 0 {
- panic(ErrShape)
- }
- return
- case 1:
- m.reuseAs(factors[0].Dims())
- m.Copy(factors[0])
- return
- case 2:
- // Don't do work that we know the answer to.
- m.Mul(factors[0], factors[1])
- return
- }
-
- p := newMultiplier(m, factors)
- p.optimize()
- result := p.multiply()
- m.reuseAs(result.Dims())
- m.Copy(result)
- putWorkspace(result)
-}
-
-// debugProductWalk enables debugging output for Product.
-const debugProductWalk = false
-
-// multiplier performs operation order optimisation and tree traversal.
-type multiplier struct {
- // factors is the ordered set of
- // factors to multiply.
- factors []Matrix
- // dims is the chain of factor
- // dimensions.
- dims []int
-
- // table contains the dynamic
- // programming costs and subchain
- // division indices.
- table table
-}
-
-func newMultiplier(m *Dense, factors []Matrix) *multiplier {
- // Check size early, but don't yet
- // allocate data for m.
- r, c := m.Dims()
- fr, fc := factors[0].Dims() // newMultiplier is only called with len(factors) > 2.
- if !m.IsZero() {
- if fr != r {
- panic(ErrShape)
- }
- if _, lc := factors[len(factors)-1].Dims(); lc != c {
- panic(ErrShape)
- }
- }
-
- dims := make([]int, len(factors)+1)
- dims[0] = r
- dims[len(dims)-1] = c
- pc := fc
- for i, f := range factors[1:] {
- cr, cc := f.Dims()
- dims[i+1] = cr
- if pc != cr {
- panic(ErrShape)
- }
- pc = cc
- }
-
- return &multiplier{
- factors: factors,
- dims: dims,
- table: newTable(len(factors)),
- }
-}
-
-// optimize determines an optimal matrix multiply operation order.
-func (p *multiplier) optimize() {
- if debugProductWalk {
- fmt.Printf("chain dims: %v\n", p.dims)
- }
- const maxInt = int(^uint(0) >> 1)
- for f := 1; f < len(p.factors); f++ {
- for i := 0; i < len(p.factors)-f; i++ {
- j := i + f
- p.table.set(i, j, entry{cost: maxInt})
- for k := i; k < j; k++ {
- cost := p.table.at(i, k).cost + p.table.at(k+1, j).cost + p.dims[i]*p.dims[k+1]*p.dims[j+1]
- if cost < p.table.at(i, j).cost {
- p.table.set(i, j, entry{cost: cost, k: k})
- }
- }
- }
- }
-}
-
-// multiply walks the optimal operation tree found by optimize,
-// leaving the final result in the stack. It returns the
-// product, which may be copied but should be returned to
-// the workspace pool.
-func (p *multiplier) multiply() *Dense {
- result, _ := p.multiplySubchain(0, len(p.factors)-1)
- if debugProductWalk {
- r, c := result.Dims()
- fmt.Printf("\tpop result (%d×%d) cost=%d\n", r, c, p.table.at(0, len(p.factors)-1).cost)
- }
- return result.(*Dense)
-}
-
-func (p *multiplier) multiplySubchain(i, j int) (m Matrix, intermediate bool) {
- if i == j {
- return p.factors[i], false
- }
-
- a, aTmp := p.multiplySubchain(i, p.table.at(i, j).k)
- b, bTmp := p.multiplySubchain(p.table.at(i, j).k+1, j)
-
- ar, ac := a.Dims()
- br, bc := b.Dims()
- if ac != br {
- // Panic with a string since this
- // is not a user-facing panic.
- panic(ErrShape.Error())
- }
-
- if debugProductWalk {
- fmt.Printf("\tpush f[%d] (%d×%d)%s * f[%d] (%d×%d)%s\n",
- i, ar, ac, result(aTmp), j, br, bc, result(bTmp))
- }
-
- r := getWorkspace(ar, bc, false)
- r.Mul(a, b)
- if aTmp {
- putWorkspace(a.(*Dense))
- }
- if bTmp {
- putWorkspace(b.(*Dense))
- }
- return r, true
-}
-
-type entry struct {
- k int // is the chain subdivision index.
- cost int // cost is the cost of the operation.
-}
-
-// table is a row major n×n dynamic programming table.
-type table struct {
- n int
- entries []entry
-}
-
-func newTable(n int) table {
- return table{n: n, entries: make([]entry, n*n)}
-}
-
-func (t table) at(i, j int) entry { return t.entries[i*t.n+j] }
-func (t table) set(i, j int, e entry) { t.entries[i*t.n+j] = e }
-
-type result bool
-
-func (r result) String() string {
- if r {
- return " (popped result)"
- }
- return ""
-}