1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "golang.org/x/exp/rand"
12 "gonum.org/v1/gonum/blas"
13 "gonum.org/v1/gonum/blas/blas64"
14 "gonum.org/v1/gonum/floats"
17 // TODO: Add tests where one of the operands is overwritten by the result.
18 func TestMulTypes(t *testing.T) {
19 for _, test := range []struct {
88 // Generate random matrices
89 avec := make([]float64, ar*ac)
91 a := NewDense(ar, ac, avec)
93 bvec := make([]float64, br*bc)
96 b := NewDense(br, bc, bvec)
98 // Check that it panics if it is supposed to
100 c := NewDense(0, 0, nil)
106 t.Errorf("Mul did not panic with dimension mismatch")
111 cvec := make([]float64, ar*bc)
113 // Get correct matrix multiply answer from blas64.Gemm
114 blas64.Gemm(blas.NoTrans, blas.NoTrans,
116 0, blas64.General{Rows: ar, Cols: bc, Stride: bc, Data: cvec},
119 avecCopy := append([]float64{}, avec...)
120 bvecCopy := append([]float64{}, bvec...)
121 cvecCopy := append([]float64{}, cvec...)
123 acomp := matComp{r: ar, c: ac, data: avecCopy}
124 bcomp := matComp{r: br, c: bc, data: bvecCopy}
125 ccomp := matComp{r: ar, c: bc, data: cvecCopy}
127 // Do normal multiply with empty dense
128 d := NewDense(0, 0, nil)
130 testMul(t, a, b, d, acomp, bcomp, ccomp, false, "zero receiver")
132 // Normal multiply with existing receiver
133 c := NewDense(ar, bc, cvec)
135 testMul(t, a, b, c, acomp, bcomp, ccomp, false, "existing receiver")
137 // Cast a as a basic matrix
138 am := (*basicMatrix)(a)
139 bm := (*basicMatrix)(b)
141 testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is zero")
143 testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is zero")
145 testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is zero")
147 testMul(t, am, b, d, acomp, bcomp, ccomp, true, "a is basic, receiver is full")
149 testMul(t, a, bm, d, acomp, bcomp, ccomp, true, "b is basic, receiver is full")
151 testMul(t, am, bm, d, acomp, bcomp, ccomp, true, "both basic, receiver is full")
155 func randomSlice(s []float64) {
157 s[i] = rand.NormFloat64()
161 type matComp struct {
166 func testMul(t *testing.T, a, b Matrix, c *Dense, acomp, bcomp, ccomp matComp, cvecApprox bool, name string) {
169 switch t := a.(type) {
177 switch t := b.(type) {
184 if !denseEqual(aDense, acomp) {
185 t.Errorf("a changed unexpectedly for %v", name)
187 if !denseEqual(bDense, bcomp) {
188 t.Errorf("b changed unexpectedly for %v", name)
191 if !denseEqualApprox(c, ccomp, 1e-14) {
192 t.Errorf("mul answer not within tol for %v", name)
197 if !denseEqual(c, ccomp) {
198 t.Errorf("mul answer not equal for %v", name)
// basicMatrix shares Dense's underlying representation. Because it is a
// distinct defined type, it does not inherit any of Dense's methods. Only the
// methods declared on *basicMatrix (At, Dims and T) are available, and At and
// Dims delegate to *Dense.
// TestMulTypes converts *Dense operands to *basicMatrix so that Mul receives
// a Matrix that is not a *Dense; those cases pass cvecApprox=true to testMul.
// NOTE(review): presumably this exercises Mul's generic (non-*Dense) code
// path — confirm against Mul's implementation.
202 type basicMatrix Dense
204 func (m *basicMatrix) At(r, c int) float64 {
205 return (*Dense)(m).At(r, c)
208 func (m *basicMatrix) Dims() (r, c int) {
209 return (*Dense)(m).Dims()
212 func (m *basicMatrix) T() Matrix {
// basicSymmetric shares SymDense's underlying representation. Because it is a
// distinct defined type, it exposes only the methods declared on
// *basicSymmetric (At, Dims, T and Symmetric), not SymDense's full method set.
// At, Dims and Symmetric delegate to *SymDense. It lets tests supply a
// Symmetric that is not a *SymDense.
216 type basicSymmetric SymDense
// Compile-time assertion that *basicSymmetric satisfies the Symmetric interface.
218 var _ Symmetric = &basicSymmetric{}
220 func (m *basicSymmetric) At(r, c int) float64 {
221 return (*SymDense)(m).At(r, c)
224 func (m *basicSymmetric) Dims() (r, c int) {
225 return (*SymDense)(m).Dims()
228 func (m *basicSymmetric) T() Matrix {
232 func (m *basicSymmetric) Symmetric() int {
233 return (*SymDense)(m).Symmetric()
// basicTriangular shares TriDense's underlying representation. Because it is a
// distinct defined type, it exposes only the methods declared on
// *basicTriangular (At, Dims, T, Triangle and TTri), not TriDense's full method
// set. At, Dims and Triangle delegate to *TriDense, and TTri wraps the receiver
// in TransposeTri.
// NOTE(review): presumably intended to supply a Triangular that is not a
// *TriDense — confirm.
236 type basicTriangular TriDense
238 func (m *basicTriangular) At(r, c int) float64 {
239 return (*TriDense)(m).At(r, c)
242 func (m *basicTriangular) Dims() (r, c int) {
243 return (*TriDense)(m).Dims()
246 func (m *basicTriangular) T() Matrix {
250 func (m *basicTriangular) Triangle() (int, TriKind) {
251 return (*TriDense)(m).Triangle()
254 func (m *basicTriangular) TTri() Triangular {
255 return TransposeTri{m}
258 func denseEqual(a *Dense, acomp matComp) bool {
266 if !floats.Equal(a.mat.Data, acomp.data) {
272 func denseEqualApprox(a *Dense, acomp matComp, tol float64) bool {
280 if !floats.EqualApprox(a.mat.Data, acomp.data, tol) {