1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
11 "golang.org/x/exp/rand"
// dims describes the size of one matrix factor in a product chain:
// r is the number of rows and c is the number of columns.
type dims struct {
	r int
	c int
}
16 var productTests = []struct {
24 factors: []dims{{3, 4}},
30 factors: []dims{{2, 4}},
36 factors: []dims{{10, 30}, {30, 5}, {5, 60}},
37 product: dims{10, 60},
42 factors: []dims{{100, 30}, {30, 5}, {5, 60}},
43 product: dims{10, 60},
48 factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
49 product: dims{60, 10},
54 factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
55 product: dims{60, 10},
60 factors: []dims{{1, 1000}, {1000, 2}, {2, 2}},
73 product: dims{60, 10},
78 product: dims{60, 10},
83 product: dims{60, 10},
88 product: dims{60, 10},
// TestProduct runs every productTests case. It builds the factor matrices,
// either from test.factors or from a random chain of test.n factors whose
// outer dimensions match test.product. It multiplies them with got.Product
// and then checks three things:
//   - whether Product panicked, via the fail-to-panic and unexpected-panic reports;
//   - that the multiplier's chain cost matches a brute-force search;
//   - that the numeric result matches a reference product.
// NOTE(review): this listing is incomplete (e.g. the panic expectation, the
// want setup and the optimize/ok lines are not shown) — confirm these
// comments against the full file.
93 func TestProduct(t *testing.T) {
94 for _, test := range productTests {
95 dimensions := test.factors
96 if dimensions == nil && test.n > 0 {
97 dimensions = make([]dims, test.n)
// Chain the factors so each factor's rows equal the previous factor's
// columns; column counts are drawn uniformly from [1, 50].
98 for i := range dimensions {
100 dimensions[i].r = dimensions[i-1].c
102 dimensions[i].c = rand.Intn(50) + 1
// Pin the outer dimensions so the chain yields a test.product-shaped result.
104 dimensions[0].r = test.product.r
105 dimensions[test.n-1].c = test.product.c
// NOTE(review): sized by test.n rather than len(dimensions). Entries that
// set test.factors presumably also set n == len(factors) — confirm. Otherwise
// factors[i] below indexes out of range or leaves nil Matrix entries.
107 factors := make([]Matrix, test.n)
108 for i, d := range dimensions {
109 data := make([]float64, d.r*d.c)
// This inner i shadows the factor index; the outer i is used again after the loop.
110 for i := range data {
111 data[i] = rand.Float64()
113 factors[i] = NewDense(d.r, d.c, data)
// want presumably accumulates the reference product by pairwise
// multiplication of the factors (the full loop body is not shown here).
119 for i, b := range factors {
124 a, want = want, &Dense{}
129 got := NewDense(test.product.r, test.product.c, nil)
130 panicked, message := panics(func() {
131 got.Product(factors...)
135 t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v",
136 dimensions, test.product)
140 t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v",
141 message, dimensions, test.product)
// Compare the multiplier's cost for the whole chain with the minimum cost
// found by exhaustively enumerating parenthesizations. ok reports whether
// the enumeration visited the expected Catalan number of expressions.
145 if len(factors) > 0 {
146 p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors)
148 gotCost := p.table.at(0, len(factors)-1).cost
149 expr, wantCost, ok := bestExpressionFor(dimensions)
151 t.Fatal("unexpected number of expressions in brute force expression search")
153 if gotCost != wantCost {
// NOTE(review): this message reports a cost mismatch but passes the result
// matrices got and want; it presumably should pass gotCost and wantCost.
154 t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s",
155 dimensions, got, want, expr)
// The product must match the reference result to within 1e-14.
159 if !EqualApprox(got, want, 1e-14) {
160 t.Errorf("unexpected result from product chain dimensions: %+v", dimensions)
165 // node is a subexpression node.
171 func (n *node) String() string {
172 if n.left == nil || n.right == nil {
173 rows, cols := n.shape()
174 return fmt.Sprintf("[%d×%d]", rows, cols)
176 rows, cols := n.shape()
177 return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols)
180 // shape returns the dimensions of the result of the subexpression.
181 func (n *node) shape() (rows, cols int) {
182 if n.left == nil || n.right == nil {
185 rows, _ = n.left.shape()
186 _, cols = n.right.shape()
190 // cost returns the cost to evaluate the subexpression.
191 func (n *node) cost() int {
192 if n.left == nil || n.right == nil {
195 lr, lc := n.left.shape()
196 _, rc := n.right.shape()
197 return lr*lc*rc + n.left.cost() + n.right.cost()
200 // expressionsFor returns a channel that can be used to iterate over all
201 // expressions of the given factor dimensions.
202 func expressionsFor(factors []dims) chan *node {
203 if len(factors) == 1 {
204 c := make(chan *node, 1)
205 c <- &node{dims: factors[0]}
209 c := make(chan *node)
211 for i := 1; i < len(factors); i++ {
212 for left := range expressionsFor(factors[:i]) {
213 for right := range expressionsFor(factors[i:]) {
214 c <- &node{left: left, right: right}
223 // catalan returns the nth 0-based Catalan number.
224 func catalan(n int) int {
226 for k := n + 1; k < 2*n+1; k++ {
229 for k := 2; k < n+2; k++ {
235 // bestExpressionFor returns the lowest cost expression for the given expression
236 // factor dimensions, the cost of the expression and whether the number of
237 // expressions searched matches the Catalan number for the number of factors.
238 func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) {
239 const maxInt = int(^uint(0) >> 1)
243 for exp := range expressionsFor(factors) {
251 return best, min, n == catalan(len(factors)-1)