OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / product_test.go
1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "fmt"
9         "testing"
10
11         "golang.org/x/exp/rand"
12 )
13
14 type dims struct{ r, c int }
15
16 var productTests = []struct {
17         n       int
18         factors []dims
19         product dims
20         panics  bool
21 }{
22         {
23                 n:       1,
24                 factors: []dims{{3, 4}},
25                 product: dims{3, 4},
26                 panics:  false,
27         },
28         {
29                 n:       1,
30                 factors: []dims{{2, 4}},
31                 product: dims{3, 4},
32                 panics:  true,
33         },
34         {
35                 n:       3,
36                 factors: []dims{{10, 30}, {30, 5}, {5, 60}},
37                 product: dims{10, 60},
38                 panics:  false,
39         },
40         {
41                 n:       3,
42                 factors: []dims{{100, 30}, {30, 5}, {5, 60}},
43                 product: dims{10, 60},
44                 panics:  true,
45         },
46         {
47                 n:       7,
48                 factors: []dims{{60, 5}, {5, 5}, {5, 4}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
49                 product: dims{60, 10},
50                 panics:  false,
51         },
52         {
53                 n:       7,
54                 factors: []dims{{60, 5}, {5, 5}, {5, 400}, {4, 10}, {10, 22}, {22, 45}, {45, 10}},
55                 product: dims{60, 10},
56                 panics:  true,
57         },
58         {
59                 n:       3,
60                 factors: []dims{{1, 1000}, {1000, 2}, {2, 2}},
61                 product: dims{1, 2},
62                 panics:  false,
63         },
64
65         // Random chains.
66         {
67                 n:       0,
68                 product: dims{0, 0},
69                 panics:  false,
70         },
71         {
72                 n:       2,
73                 product: dims{60, 10},
74                 panics:  false,
75         },
76         {
77                 n:       3,
78                 product: dims{60, 10},
79                 panics:  false,
80         },
81         {
82                 n:       4,
83                 product: dims{60, 10},
84                 panics:  false,
85         },
86         {
87                 n:       10,
88                 product: dims{60, 10},
89                 panics:  false,
90         },
91 }
92
93 func TestProduct(t *testing.T) {
94         for _, test := range productTests {
95                 dimensions := test.factors
96                 if dimensions == nil && test.n > 0 {
97                         dimensions = make([]dims, test.n)
98                         for i := range dimensions {
99                                 if i != 0 {
100                                         dimensions[i].r = dimensions[i-1].c
101                                 }
102                                 dimensions[i].c = rand.Intn(50) + 1
103                         }
104                         dimensions[0].r = test.product.r
105                         dimensions[test.n-1].c = test.product.c
106                 }
107                 factors := make([]Matrix, test.n)
108                 for i, d := range dimensions {
109                         data := make([]float64, d.r*d.c)
110                         for i := range data {
111                                 data[i] = rand.Float64()
112                         }
113                         factors[i] = NewDense(d.r, d.c, data)
114                 }
115
116                 want := &Dense{}
117                 if !test.panics {
118                         a := &Dense{}
119                         for i, b := range factors {
120                                 if i == 0 {
121                                         want.Clone(b)
122                                         continue
123                                 }
124                                 a, want = want, &Dense{}
125                                 want.Mul(a, b)
126                         }
127                 }
128
129                 got := NewDense(test.product.r, test.product.c, nil)
130                 panicked, message := panics(func() {
131                         got.Product(factors...)
132                 })
133                 if test.panics {
134                         if !panicked {
135                                 t.Errorf("fail to panic with product chain dimensions: %+v result dimension: %+v",
136                                         dimensions, test.product)
137                         }
138                         continue
139                 } else if panicked {
140                         t.Errorf("unexpected panic %q with product chain dimensions: %+v result dimension: %+v",
141                                 message, dimensions, test.product)
142                         continue
143                 }
144
145                 if len(factors) > 0 {
146                         p := newMultiplier(NewDense(test.product.r, test.product.c, nil), factors)
147                         p.optimize()
148                         gotCost := p.table.at(0, len(factors)-1).cost
149                         expr, wantCost, ok := bestExpressionFor(dimensions)
150                         if !ok {
151                                 t.Fatal("unexpected number of expressions in brute force expression search")
152                         }
153                         if gotCost != wantCost {
154                                 t.Errorf("unexpected cost for chain dimensions: %+v got: %v want: %v\n%s",
155                                         dimensions, got, want, expr)
156                         }
157                 }
158
159                 if !EqualApprox(got, want, 1e-14) {
160                         t.Errorf("unexpected result from product chain dimensions: %+v", dimensions)
161                 }
162         }
163 }
164
165 // node is a subexpression node.
166 type node struct {
167         dims
168         left, right *node
169 }
170
171 func (n *node) String() string {
172         if n.left == nil || n.right == nil {
173                 rows, cols := n.shape()
174                 return fmt.Sprintf("[%d×%d]", rows, cols)
175         }
176         rows, cols := n.shape()
177         return fmt.Sprintf("(%s * %s):[%d×%d]", n.left, n.right, rows, cols)
178 }
179
180 // shape returns the dimensions of the result of the subexpression.
181 func (n *node) shape() (rows, cols int) {
182         if n.left == nil || n.right == nil {
183                 return n.r, n.c
184         }
185         rows, _ = n.left.shape()
186         _, cols = n.right.shape()
187         return rows, cols
188 }
189
190 // cost returns the cost to evaluate the subexpression.
191 func (n *node) cost() int {
192         if n.left == nil || n.right == nil {
193                 return 0
194         }
195         lr, lc := n.left.shape()
196         _, rc := n.right.shape()
197         return lr*lc*rc + n.left.cost() + n.right.cost()
198 }
199
200 // expressionsFor returns a channel that can be used to iterate over all
201 // expressions of the given factor dimensions.
202 func expressionsFor(factors []dims) chan *node {
203         if len(factors) == 1 {
204                 c := make(chan *node, 1)
205                 c <- &node{dims: factors[0]}
206                 close(c)
207                 return c
208         }
209         c := make(chan *node)
210         go func() {
211                 for i := 1; i < len(factors); i++ {
212                         for left := range expressionsFor(factors[:i]) {
213                                 for right := range expressionsFor(factors[i:]) {
214                                         c <- &node{left: left, right: right}
215                                 }
216                         }
217                 }
218                 close(c)
219         }()
220         return c
221 }
222
223 // catalan returns the nth 0-based Catalan number.
224 func catalan(n int) int {
225         p := 1
226         for k := n + 1; k < 2*n+1; k++ {
227                 p *= k
228         }
229         for k := 2; k < n+2; k++ {
230                 p /= k
231         }
232         return p
233 }
234
235 // bestExpressonFor returns the lowest cost expression for the given expression
236 // factor dimensions, the cost of the expression and whether the number of
237 // expressions searched matches the Catalan number for the number of factors.
238 func bestExpressionFor(factors []dims) (exp *node, cost int, ok bool) {
239         const maxInt = int(^uint(0) >> 1)
240         min := maxInt
241         var best *node
242         var n int
243         for exp := range expressionsFor(factors) {
244                 n++
245                 cost := exp.cost()
246                 if cost < min {
247                         min = cost
248                         best = exp
249                 }
250         }
251         return best, min, n == catalan(len(factors)-1)
252 }