OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / lapack / testlapack / dlapmt.go
1 // Copyright ©2017 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package testlapack
6
7 import (
8         "fmt"
9         "testing"
10
11         "gonum.org/v1/gonum/blas/blas64"
12 )
13
// Dlapmter is implemented by types that provide Dlapmt, which rearranges the
// columns of the m×n matrix held in x (row stride ldx) according to the
// permutation k, in the direction selected by forward.
type Dlapmter interface {
	Dlapmt(forward bool, m int, n int, x []float64, ldx int, k []int)
}
17
18 func DlapmtTest(t *testing.T, impl Dlapmter) {
19         for ti, test := range []struct {
20                 forward bool
21                 k       []int
22
23                 want blas64.General
24         }{
25                 {
26                         forward: true, k: []int{0, 1, 2},
27                         want: blas64.General{
28                                 Rows:   4,
29                                 Cols:   3,
30                                 Stride: 3,
31                                 Data: []float64{
32                                         1, 2, 3,
33                                         4, 5, 6,
34                                         7, 8, 9,
35                                         10, 11, 12,
36                                 },
37                         },
38                 },
39                 {
40                         forward: false, k: []int{0, 1, 2},
41                         want: blas64.General{
42                                 Rows:   4,
43                                 Cols:   3,
44                                 Stride: 3,
45                                 Data: []float64{
46                                         1, 2, 3,
47                                         4, 5, 6,
48                                         7, 8, 9,
49                                         10, 11, 12,
50                                 },
51                         },
52                 },
53                 {
54                         forward: true, k: []int{1, 2, 0},
55                         want: blas64.General{
56                                 Rows:   4,
57                                 Cols:   3,
58                                 Stride: 3,
59                                 Data: []float64{
60                                         2, 3, 1,
61                                         5, 6, 4,
62                                         8, 9, 7,
63                                         11, 12, 10,
64                                 },
65                         },
66                 },
67                 {
68                         forward: false, k: []int{1, 2, 0},
69                         want: blas64.General{
70                                 Rows:   4,
71                                 Cols:   3,
72                                 Stride: 3,
73                                 Data: []float64{
74                                         3, 1, 2,
75                                         6, 4, 5,
76                                         9, 7, 8,
77                                         12, 10, 11,
78                                 },
79                         },
80                 },
81         } {
82                 m := test.want.Rows
83                 n := test.want.Cols
84                 if len(test.k) != n {
85                         panic("bad length of k")
86                 }
87
88                 for _, extra := range []int{0, 11} {
89                         x := zeros(m, n, n+extra)
90                         c := 1
91                         for i := 0; i < m; i++ {
92                                 for j := 0; j < n; j++ {
93                                         x.Data[i*x.Stride+j] = float64(c)
94                                         c++
95                                 }
96                         }
97
98                         k := make([]int, len(test.k))
99                         copy(k, test.k)
100
101                         impl.Dlapmt(test.forward, m, n, x.Data, x.Stride, k)
102
103                         prefix := fmt.Sprintf("Case %v (forward=%t,m=%v,n=%v,extra=%v)", ti, test.forward, m, n, extra)
104                         if !generalOutsideAllNaN(x) {
105                                 t.Errorf("%v: out-of-range write to X", prefix)
106                         }
107
108                         if !equalApproxGeneral(x, test.want, 0) {
109                                 t.Errorf("%v: unexpected X\n%v\n%v", prefix, x, test.want)
110                         }
111                 }
112         }
113 }