+++ /dev/null
-// Copyright ©2015 The Gonum Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package mat
-
-import (
- "fmt"
- "testing"
-
- "golang.org/x/exp/rand"
-)
-
-type dims struct{ r, c int }
-
-var productTests = []struct {
- n int
- factors []dims
- product dims
- panics bool
-}{
- {
- n: 1,
- factors: []dims{{3, 4}},
- product: dims{3, 4},
- panics: false,
- },
- {
- n: 1,
- factors: []dims{{2, 4}},
- product: dims{3, 4},
- panics: true,
- },
- {
- n: 3,
- factors: []dims{{10, 30}, {30, 5}, {5, 60}},
- product: dims{10, 60},
- panics: false,
- },
- {
- n: 3,
- factors: []dims{{100, 30}, {30, 5}, {5, 60}},
- product: dims{10, 60},
- panics: true,
- },
- {
- n: 7,
- factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
- product: dims{60, 10},
- panics: false,
- },
- {
- n: 7,
- factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
- product: dims{60, 10},
- panics: true,
- },
- {
- n: 3,
- factors: []dims{{1, 1000}, {1000, 2}, {2, 2}},
- product: dims{1, 2},
- panics: false,
- },
-
- // Random chains.
- {
- n: 0,
- product: dims{0, 0},
- panics: false,
- },
- {
- n: 2,
- product: dims{60, 10},
- panics: false,
- },
- {
- n: 3,
- product: dims{60, 10},
- panics: false,
- },
- {
- n: 4,
- product: dims{60, 10},
- panics: false,
- },
- {
- n: 10,
- product: dims{60, 10},
- panics: false,
- },
-}
-
-func TestProduct(t *testing.T) {
- for _, test := range productTests {
- dimensions := test.factors
- if dimensions == nil && test.n > 0 {
- dimensions = make([]dims, test.n)
- for i := range dimensions {
- if i != 0 {
- dimensions[i].r = dimensions[i-1].c
- }
- dimensions[i].c = rand.Intn(50) + 1
- }
- dimensions[0].r = test.product.r
- dimensions[test.n-1].c = test.product.c
- }
- factors := make([]Matrix, test.n)
- for i, d := range dimensions {
- data := make([]float64, d.r*d.c)
- for i := range data {
- data[i] = rand.Float64()
- }
- factors[i] = NewDense(d.r, d.c, data)
- }
-
- want := &Dense{}
- if !test.panics {
- a := &Dense{}
- for i, b := range factors {
- if i == 0 {
- want.Clone(b)
- continue
- }
- a, want = want, &Dense{}
- want.Mul(a, b)
- }
- }
-
- got := NewDense(test.product.r, test.product.c, nil)
- panicked, message := panics(func() {
- got.Product(factors...)
- })
- if test.panics {
- if !panicked {
- t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v",
- dimensions, test.product)
- }
- continue
- } else if panicked {
- t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v",
- message, dimensions, test.product)
- continue
- }
-
- if len(factors) > 0 {
- p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors)
- p.optimize()
- gotCost := p.table.at(0, len(factors)-1).cost
- expr, wantCost, ok := bestExpressionFor(dimensions)
- if !ok {
- t.Fatal("unexpected number of expressions in brute force expression search")
- }
- if gotCost != wantCost {
- t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s",
- dimensions, got, want, expr)
- }
- }
-
- if !EqualApprox(got, want, 1e-14) {
- t.Errorf("unexpected result from product chain dimensions: %+v", dimensions)
- }
- }
-}
-
-// node is a subexpression node.
-type node struct {
- dims
- left, right *node
-}
-
-func (n *node) String() string {
- if n.left == nil || n.right == nil {
- rows, cols := n.shape()
- return fmt.Sprintf("[%d×%d]", rows, cols)
- }
- rows, cols := n.shape()
- return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols)
-}
-
-// shape returns the dimensions of the result of the subexpression.
-func (n *node) shape() (rows, cols int) {
- if n.left == nil || n.right == nil {
- return n.r, n.c
- }
- rows, _ = n.left.shape()
- _, cols = n.right.shape()
- return rows, cols
-}
-
-// cost returns the cost to evaluate the subexpression.
-func (n *node) cost() int {
- if n.left == nil || n.right == nil {
- return 0
- }
- lr, lc := n.left.shape()
- _, rc := n.right.shape()
- return lr*lc*rc + n.left.cost() + n.right.cost()
-}
-
-// expressionsFor returns a channel that can be used to iterate over all
-// expressions of the given factor dimensions.
-func expressionsFor(factors []dims) chan *node {
- if len(factors) == 1 {
- c := make(chan *node, 1)
- c <- &node{dims: factors[0]}
- close(c)
- return c
- }
- c := make(chan *node)
- go func() {
- for i := 1; i < len(factors); i++ {
- for left := range expressionsFor(factors[:i]) {
- for right := range expressionsFor(factors[i:]) {
- c <- &node{left: left, right: right}
- }
- }
- }
- close(c)
- }()
- return c
-}
-
-// catalan returns the nth 0-based Catalan number.
-func catalan(n int) int {
- p := 1
- for k := n + 1; k < 2*n+1; k++ {
- p *= k
- }
- for k := 2; k < n+2; k++ {
- p /= k
- }
- return p
-}
-
-// bestExpressonFor returns the lowest cost expression for the given expression
-// factor dimensions, the cost of the expression and whether the number of
-// expressions searched matches the Catalan number for the number of factors.
-func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) {
- const maxInt = int(^uint(0) >> 1)
- min := maxInt
- var best *node
- var n int
- for exp := range expressionsFor(factors) {
- n++
- cost := exp.cost()
- if cost < min {
- min = cost
- best = exp
- }
- }
- return best, min, n == catalan(len(factors)-1)
-}