OSDN Git Service

test (#52)
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / matrix.go
1 // Copyright ©2013 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "math"
9
10         "gonum.org/v1/gonum/blas"
11         "gonum.org/v1/gonum/blas/blas64"
12         "gonum.org/v1/gonum/floats"
13         "gonum.org/v1/gonum/lapack"
14         "gonum.org/v1/gonum/lapack/lapack64"
15 )
16
17 // Matrix is the basic matrix interface type.
18 type Matrix interface {
19         // Dims returns the dimensions of a Matrix.
20         Dims() (r, c int)
21
22         // At returns the value of a matrix element at row i, column j.
23         // It will panic if i or j are out of bounds for the matrix.
24         At(i, j int) float64
25
26         // T returns the transpose of the Matrix. Whether T returns a copy of the
27         // underlying data is implementation dependent.
28         // This method may be implemented using the Transpose type, which
29         // provides an implicit matrix transpose.
30         T() Matrix
31 }
32
33 var (
34         _ Matrix       = Transpose{}
35         _ Untransposer = Transpose{}
36 )
37
38 // Transpose is a type for performing an implicit matrix transpose. It implements
39 // the Matrix interface, returning values from the transpose of the matrix within.
40 type Transpose struct {
41         Matrix Matrix
42 }
43
44 // At returns the value of the element at row i and column j of the transposed
45 // matrix, that is, row j and column i of the Matrix field.
46 func (t Transpose) At(i, j int) float64 {
47         return t.Matrix.At(j, i)
48 }
49
50 // Dims returns the dimensions of the transposed matrix. The number of rows returned
51 // is the number of columns in the Matrix field, and the number of columns is
52 // the number of rows in the Matrix field.
53 func (t Transpose) Dims() (r, c int) {
54         c, r = t.Matrix.Dims()
55         return r, c
56 }
57
58 // T performs an implicit transpose by returning the Matrix field.
59 func (t Transpose) T() Matrix {
60         return t.Matrix
61 }
62
63 // Untranspose returns the Matrix field.
64 func (t Transpose) Untranspose() Matrix {
65         return t.Matrix
66 }
67
68 // Untransposer is a type that can undo an implicit transpose.
69 type Untransposer interface {
70         // Note: This interface is needed to unify all of the Transpose types. In
71         // the mat methods, we need to test if the Matrix has been implicitly
72         // transposed. If this is checked by testing for the specific Transpose type
73         // then the behavior will be different if the user uses T() or TTri() for a
74         // triangular matrix.
75
76         // Untranspose returns the underlying Matrix stored for the implicit transpose.
77         Untranspose() Matrix
78 }
79
80 // UntransposeBander is a type that can undo an implicit band transpose.
81 type UntransposeBander interface {
82         // Untranspose returns the underlying Banded stored for the implicit transpose.
83         UntransposeBand() Banded
84 }
85
86 // UntransposeTrier is a type that can undo an implicit triangular transpose.
87 type UntransposeTrier interface {
88         // Untranspose returns the underlying Triangular stored for the implicit transpose.
89         UntransposeTri() Triangular
90 }
91
92 // Mutable is a matrix interface type that allows elements to be altered.
93 type Mutable interface {
94         // Set alters the matrix element at row i, column j to v.
95         // It will panic if i or j are out of bounds for the matrix.
96         Set(i, j int, v float64)
97
98         Matrix
99 }
100
101 // A RowViewer can return a Vector reflecting a row that is backed by the matrix
102 // data. The Vector returned will have length equal to the number of columns.
103 type RowViewer interface {
104         RowView(i int) Vector
105 }
106
107 // A RawRowViewer can return a slice of float64 reflecting a row that is backed by the matrix
108 // data.
109 type RawRowViewer interface {
110         RawRowView(i int) []float64
111 }
112
113 // A ColViewer can return a Vector reflecting a column that is backed by the matrix
114 // data. The Vector returned will have length equal to the number of rows.
115 type ColViewer interface {
116         ColView(j int) Vector
117 }
118
119 // A RawColViewer can return a slice of float64 reflecting a column that is backed by the matrix
120 // data.
121 type RawColViewer interface {
122         RawColView(j int) []float64
123 }
124
125 // A Cloner can make a copy of a into the receiver, overwriting the previous value of the
126 // receiver. The clone operation does not make any restriction on shape and will not cause
127 // shadowing.
128 type Cloner interface {
129         Clone(a Matrix)
130 }
131
132 // A Reseter can reset the matrix so that it can be reused as the receiver of a dimensionally
133 // restricted operation. This is commonly used when the matrix is being used as a workspace
134 // or temporary matrix.
135 //
136 // If the matrix is a view, using the reset matrix may result in data corruption in elements
137 // outside the view.
138 type Reseter interface {
139         Reset()
140 }
141
142 // A Copier can make a copy of elements of a into the receiver. The submatrix copied
143 // starts at row and column 0 and has dimensions equal to the minimum dimensions of
144 // the two matrices. The number of row and columns copied is returned.
145 // Copy will copy from a source that aliases the receiver unless the source is transposed;
146 // an aliasing transpose copy will panic with the exception for a special case when
147 // the source data has a unitary increment or stride.
148 type Copier interface {
149         Copy(a Matrix) (r, c int)
150 }
151
152 // A Grower can grow the size of the represented matrix by the given number of rows and columns.
153 // Growing beyond the size given by the Caps method will result in the allocation of a new
154 // matrix and copying of the elements. If Grow is called with negative increments it will
155 // panic with ErrIndexOutOfRange.
156 type Grower interface {
157         Caps() (r, c int)
158         Grow(r, c int) Matrix
159 }
160
161 // A BandWidther represents a banded matrix and can return the left and right half-bandwidths, k1 and
162 // k2.
163 type BandWidther interface {
164         BandWidth() (k1, k2 int)
165 }
166
167 // A RawMatrixSetter can set the underlying blas64.General used by the receiver. There is no restriction
168 // on the shape of the receiver. Changes to the receiver's elements will be reflected in the blas64.General.Data.
169 type RawMatrixSetter interface {
170         SetRawMatrix(a blas64.General)
171 }
172
173 // A RawMatrixer can return a blas64.General representation of the receiver. Changes to the blas64.General.Data
174 // slice will be reflected in the original matrix, changes to the Rows, Cols and Stride fields will not.
175 type RawMatrixer interface {
176         RawMatrix() blas64.General
177 }
178
179 // A RawVectorer can return a blas64.Vector representation of the receiver. Changes to the blas64.Vector.Data
180 // slice will be reflected in the original matrix, changes to the Inc field will not.
181 type RawVectorer interface {
182         RawVector() blas64.Vector
183 }
184
185 // A NonZeroDoer can call a function for each non-zero element of the receiver.
186 // The parameters of the function are the element indices and its value.
187 type NonZeroDoer interface {
188         DoNonZero(func(i, j int, v float64))
189 }
190
191 // A RowNonZeroDoer can call a function for each non-zero element of a row of the receiver.
192 // The parameters of the function are the element indices and its value.
193 type RowNonZeroDoer interface {
194         DoRowNonZero(i int, fn func(i, j int, v float64))
195 }
196
197 // A ColNonZeroDoer can call a function for each non-zero element of a column of the receiver.
198 // The parameters of the function are the element indices and its value.
199 type ColNonZeroDoer interface {
200         DoColNonZero(j int, fn func(i, j int, v float64))
201 }
202
203 // TODO(btracey): Consider adding CopyCol/CopyRow if the behavior seems useful.
204 // TODO(btracey): Add in fast paths to Row/Col for the other concrete types
205 // (TriDense, etc.) as well as relevant interfaces (RowColer, RawRowViewer, etc.)
206
207 // Col copies the elements in the jth column of the matrix into the slice dst.
208 // The length of the provided slice must equal the number of rows, unless the
209 // slice is nil in which case a new slice is first allocated.
210 func Col(dst []float64, j int, a Matrix) []float64 {
211         r, c := a.Dims()
212         if j < 0 || j >= c {
213                 panic(ErrColAccess)
214         }
215         if dst == nil {
216                 dst = make([]float64, r)
217         } else {
218                 if len(dst) != r {
219                         panic(ErrColLength)
220                 }
221         }
222         aU, aTrans := untranspose(a)
223         if rm, ok := aU.(RawMatrixer); ok {
224                 m := rm.RawMatrix()
225                 if aTrans {
226                         copy(dst, m.Data[j*m.Stride:j*m.Stride+m.Cols])
227                         return dst
228                 }
229                 blas64.Copy(r,
230                         blas64.Vector{Inc: m.Stride, Data: m.Data[j:]},
231                         blas64.Vector{Inc: 1, Data: dst},
232                 )
233                 return dst
234         }
235         for i := 0; i < r; i++ {
236                 dst[i] = a.At(i, j)
237         }
238         return dst
239 }
240
241 // Row copies the elements in the ith row of the matrix into the slice dst.
242 // The length of the provided slice must equal the number of columns, unless the
243 // slice is nil in which case a new slice is first allocated.
244 func Row(dst []float64, i int, a Matrix) []float64 {
245         r, c := a.Dims()
246         if i < 0 || i >= r {
247                 panic(ErrColAccess)
248         }
249         if dst == nil {
250                 dst = make([]float64, c)
251         } else {
252                 if len(dst) != c {
253                         panic(ErrRowLength)
254                 }
255         }
256         aU, aTrans := untranspose(a)
257         if rm, ok := aU.(RawMatrixer); ok {
258                 m := rm.RawMatrix()
259                 if aTrans {
260                         blas64.Copy(c,
261                                 blas64.Vector{Inc: m.Stride, Data: m.Data[i:]},
262                                 blas64.Vector{Inc: 1, Data: dst},
263                         )
264                         return dst
265                 }
266                 copy(dst, m.Data[i*m.Stride:i*m.Stride+m.Cols])
267                 return dst
268         }
269         for j := 0; j < c; j++ {
270                 dst[j] = a.At(i, j)
271         }
272         return dst
273 }
274
275 // Cond returns the condition number of the given matrix under the given norm.
276 // The condition number must be based on the 1-norm, 2-norm or ∞-norm.
277 // Cond will panic with matrix.ErrShape if the matrix has zero size.
278 //
279 // BUG(btracey): The computation of the 1-norm and ∞-norm for non-square matrices
280 // is innacurate, although is typically the right order of magnitude. See
281 // https://github.com/xianyi/OpenBLAS/issues/636. While the value returned will
282 // change with the resolution of this bug, the result from Cond will match the
283 // condition number used internally.
284 func Cond(a Matrix, norm float64) float64 {
285         m, n := a.Dims()
286         if m == 0 || n == 0 {
287                 panic(ErrShape)
288         }
289         var lnorm lapack.MatrixNorm
290         switch norm {
291         default:
292                 panic("mat: bad norm value")
293         case 1:
294                 lnorm = lapack.MaxColumnSum
295         case 2:
296                 var svd SVD
297                 ok := svd.Factorize(a, SVDNone)
298                 if !ok {
299                         return math.Inf(1)
300                 }
301                 return svd.Cond()
302         case math.Inf(1):
303                 lnorm = lapack.MaxRowSum
304         }
305
306         if m == n {
307                 // Use the LU decomposition to compute the condition number.
308                 var lu LU
309                 lu.factorize(a, lnorm)
310                 return lu.Cond()
311         }
312         if m > n {
313                 // Use the QR factorization to compute the condition number.
314                 var qr QR
315                 qr.factorize(a, lnorm)
316                 return qr.Cond()
317         }
318         // Use the LQ factorization to compute the condition number.
319         var lq LQ
320         lq.factorize(a, lnorm)
321         return lq.Cond()
322 }
323
324 // Det returns the determinant of the matrix a. In many expressions using LogDet
325 // will be more numerically stable.
326 func Det(a Matrix) float64 {
327         det, sign := LogDet(a)
328         return math.Exp(det) * sign
329 }
330
331 // Dot returns the sum of the element-wise product of a and b.
332 // Dot panics if the matrix sizes are unequal.
333 func Dot(a, b Vector) float64 {
334         la := a.Len()
335         lb := b.Len()
336         if la != lb {
337                 panic(ErrShape)
338         }
339         if arv, ok := a.(RawVectorer); ok {
340                 if brv, ok := b.(RawVectorer); ok {
341                         return blas64.Dot(la, arv.RawVector(), brv.RawVector())
342                 }
343         }
344         var sum float64
345         for i := 0; i < la; i++ {
346                 sum += a.At(i, 0) * b.At(i, 0)
347         }
348         return sum
349 }
350
351 // Equal returns whether the matrices a and b have the same size
352 // and are element-wise equal.
353 func Equal(a, b Matrix) bool {
354         ar, ac := a.Dims()
355         br, bc := b.Dims()
356         if ar != br || ac != bc {
357                 return false
358         }
359         aU, aTrans := untranspose(a)
360         bU, bTrans := untranspose(b)
361         if rma, ok := aU.(RawMatrixer); ok {
362                 if rmb, ok := bU.(RawMatrixer); ok {
363                         ra := rma.RawMatrix()
364                         rb := rmb.RawMatrix()
365                         if aTrans == bTrans {
366                                 for i := 0; i < ra.Rows; i++ {
367                                         for j := 0; j < ra.Cols; j++ {
368                                                 if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
369                                                         return false
370                                                 }
371                                         }
372                                 }
373                                 return true
374                         }
375                         for i := 0; i < ra.Rows; i++ {
376                                 for j := 0; j < ra.Cols; j++ {
377                                         if ra.Data[i*ra.Stride+j] != rb.Data[j*rb.Stride+i] {
378                                                 return false
379                                         }
380                                 }
381                         }
382                         return true
383                 }
384         }
385         if rma, ok := aU.(RawSymmetricer); ok {
386                 if rmb, ok := bU.(RawSymmetricer); ok {
387                         ra := rma.RawSymmetric()
388                         rb := rmb.RawSymmetric()
389                         // Symmetric matrices are always upper and equal to their transpose.
390                         for i := 0; i < ra.N; i++ {
391                                 for j := i; j < ra.N; j++ {
392                                         if ra.Data[i*ra.Stride+j] != rb.Data[i*rb.Stride+j] {
393                                                 return false
394                                         }
395                                 }
396                         }
397                         return true
398                 }
399         }
400         if ra, ok := aU.(*VecDense); ok {
401                 if rb, ok := bU.(*VecDense); ok {
402                         // If the raw vectors are the same length they must either both be
403                         // transposed or both not transposed (or have length 1).
404                         for i := 0; i < ra.n; i++ {
405                                 if ra.mat.Data[i*ra.mat.Inc] != rb.mat.Data[i*rb.mat.Inc] {
406                                         return false
407                                 }
408                         }
409                         return true
410                 }
411         }
412         for i := 0; i < ar; i++ {
413                 for j := 0; j < ac; j++ {
414                         if a.At(i, j) != b.At(i, j) {
415                                 return false
416                         }
417                 }
418         }
419         return true
420 }
421
422 // EqualApprox returns whether the matrices a and b have the same size and contain all equal
423 // elements with tolerance for element-wise equality specified by epsilon. Matrices
424 // with non-equal shapes are not equal.
425 func EqualApprox(a, b Matrix, epsilon float64) bool {
426         ar, ac := a.Dims()
427         br, bc := b.Dims()
428         if ar != br || ac != bc {
429                 return false
430         }
431         aU, aTrans := untranspose(a)
432         bU, bTrans := untranspose(b)
433         if rma, ok := aU.(RawMatrixer); ok {
434                 if rmb, ok := bU.(RawMatrixer); ok {
435                         ra := rma.RawMatrix()
436                         rb := rmb.RawMatrix()
437                         if aTrans == bTrans {
438                                 for i := 0; i < ra.Rows; i++ {
439                                         for j := 0; j < ra.Cols; j++ {
440                                                 if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
441                                                         return false
442                                                 }
443                                         }
444                                 }
445                                 return true
446                         }
447                         for i := 0; i < ra.Rows; i++ {
448                                 for j := 0; j < ra.Cols; j++ {
449                                         if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[j*rb.Stride+i], epsilon, epsilon) {
450                                                 return false
451                                         }
452                                 }
453                         }
454                         return true
455                 }
456         }
457         if rma, ok := aU.(RawSymmetricer); ok {
458                 if rmb, ok := bU.(RawSymmetricer); ok {
459                         ra := rma.RawSymmetric()
460                         rb := rmb.RawSymmetric()
461                         // Symmetric matrices are always upper and equal to their transpose.
462                         for i := 0; i < ra.N; i++ {
463                                 for j := i; j < ra.N; j++ {
464                                         if !floats.EqualWithinAbsOrRel(ra.Data[i*ra.Stride+j], rb.Data[i*rb.Stride+j], epsilon, epsilon) {
465                                                 return false
466                                         }
467                                 }
468                         }
469                         return true
470                 }
471         }
472         if ra, ok := aU.(*VecDense); ok {
473                 if rb, ok := bU.(*VecDense); ok {
474                         // If the raw vectors are the same length they must either both be
475                         // transposed or both not transposed (or have length 1).
476                         for i := 0; i < ra.n; i++ {
477                                 if !floats.EqualWithinAbsOrRel(ra.mat.Data[i*ra.mat.Inc], rb.mat.Data[i*rb.mat.Inc], epsilon, epsilon) {
478                                         return false
479                                 }
480                         }
481                         return true
482                 }
483         }
484         for i := 0; i < ar; i++ {
485                 for j := 0; j < ac; j++ {
486                         if !floats.EqualWithinAbsOrRel(a.At(i, j), b.At(i, j), epsilon, epsilon) {
487                                 return false
488                         }
489                 }
490         }
491         return true
492 }
493
494 // LogDet returns the log of the determinant and the sign of the determinant
495 // for the matrix that has been factorized. Numerical stability in product and
496 // division expressions is generally improved by working in log space.
497 func LogDet(a Matrix) (det float64, sign float64) {
498         // TODO(btracey): Add specialized routines for TriDense, etc.
499         var lu LU
500         lu.Factorize(a)
501         return lu.LogDet()
502 }
503
504 // Max returns the largest element value of the matrix A.
505 // Max will panic with matrix.ErrShape if the matrix has zero size.
506 func Max(a Matrix) float64 {
507         r, c := a.Dims()
508         if r == 0 || c == 0 {
509                 panic(ErrShape)
510         }
511         // Max(A) = Max(A^T)
512         aU, _ := untranspose(a)
513         switch m := aU.(type) {
514         case RawMatrixer:
515                 rm := m.RawMatrix()
516                 max := math.Inf(-1)
517                 for i := 0; i < rm.Rows; i++ {
518                         for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
519                                 if v > max {
520                                         max = v
521                                 }
522                         }
523                 }
524                 return max
525         case RawTriangular:
526                 rm := m.RawTriangular()
527                 // The max of a triangular is at least 0 unless the size is 1.
528                 if rm.N == 1 {
529                         return rm.Data[0]
530                 }
531                 max := 0.0
532                 if rm.Uplo == blas.Upper {
533                         for i := 0; i < rm.N; i++ {
534                                 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
535                                         if v > max {
536                                                 max = v
537                                         }
538                                 }
539                         }
540                         return max
541                 }
542                 for i := 0; i < rm.N; i++ {
543                         for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
544                                 if v > max {
545                                         max = v
546                                 }
547                         }
548                 }
549                 return max
550         case RawSymmetricer:
551                 rm := m.RawSymmetric()
552                 if rm.Uplo != blas.Upper {
553                         panic(badSymTriangle)
554                 }
555                 max := math.Inf(-1)
556                 for i := 0; i < rm.N; i++ {
557                         for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
558                                 if v > max {
559                                         max = v
560                                 }
561                         }
562                 }
563                 return max
564         default:
565                 r, c := aU.Dims()
566                 max := math.Inf(-1)
567                 for i := 0; i < r; i++ {
568                         for j := 0; j < c; j++ {
569                                 v := aU.At(i, j)
570                                 if v > max {
571                                         max = v
572                                 }
573                         }
574                 }
575                 return max
576         }
577 }
578
579 // Min returns the smallest element value of the matrix A.
580 // Min will panic with matrix.ErrShape if the matrix has zero size.
581 func Min(a Matrix) float64 {
582         r, c := a.Dims()
583         if r == 0 || c == 0 {
584                 panic(ErrShape)
585         }
586         // Min(A) = Min(A^T)
587         aU, _ := untranspose(a)
588         switch m := aU.(type) {
589         case RawMatrixer:
590                 rm := m.RawMatrix()
591                 min := math.Inf(1)
592                 for i := 0; i < rm.Rows; i++ {
593                         for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
594                                 if v < min {
595                                         min = v
596                                 }
597                         }
598                 }
599                 return min
600         case RawTriangular:
601                 rm := m.RawTriangular()
602                 // The min of a triangular is at most 0 unless the size is 1.
603                 if rm.N == 1 {
604                         return rm.Data[0]
605                 }
606                 min := 0.0
607                 if rm.Uplo == blas.Upper {
608                         for i := 0; i < rm.N; i++ {
609                                 for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
610                                         if v < min {
611                                                 min = v
612                                         }
613                                 }
614                         }
615                         return min
616                 }
617                 for i := 0; i < rm.N; i++ {
618                         for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+i+1] {
619                                 if v < min {
620                                         min = v
621                                 }
622                         }
623                 }
624                 return min
625         case RawSymmetricer:
626                 rm := m.RawSymmetric()
627                 if rm.Uplo != blas.Upper {
628                         panic(badSymTriangle)
629                 }
630                 min := math.Inf(1)
631                 for i := 0; i < rm.N; i++ {
632                         for _, v := range rm.Data[i*rm.Stride+i : i*rm.Stride+rm.N] {
633                                 if v < min {
634                                         min = v
635                                 }
636                         }
637                 }
638                 return min
639         default:
640                 r, c := aU.Dims()
641                 min := math.Inf(1)
642                 for i := 0; i < r; i++ {
643                         for j := 0; j < c; j++ {
644                                 v := aU.At(i, j)
645                                 if v < min {
646                                         min = v
647                                 }
648                         }
649                 }
650                 return min
651         }
652 }
653
654 // Norm returns the specified (induced) norm of the matrix a. See
655 // https://en.wikipedia.org/wiki/Matrix_norm for the definition of an induced norm.
656 //
657 // Valid norms are:
658 //    1 - The maximum absolute column sum
659 //    2 - Frobenius norm, the square root of the sum of the squares of the elements.
660 //  Inf - The maximum absolute row sum.
661 // Norm will panic with ErrNormOrder if an illegal norm order is specified and
662 // with matrix.ErrShape if the matrix has zero size.
663 func Norm(a Matrix, norm float64) float64 {
664         r, c := a.Dims()
665         if r == 0 || c == 0 {
666                 panic(ErrShape)
667         }
668         aU, aTrans := untranspose(a)
669         var work []float64
670         switch rma := aU.(type) {
671         case RawMatrixer:
672                 rm := rma.RawMatrix()
673                 n := normLapack(norm, aTrans)
674                 if n == lapack.MaxColumnSum {
675                         work = getFloats(rm.Cols, false)
676                         defer putFloats(work)
677                 }
678                 return lapack64.Lange(n, rm, work)
679         case RawTriangular:
680                 rm := rma.RawTriangular()
681                 n := normLapack(norm, aTrans)
682                 if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
683                         work = getFloats(rm.N, false)
684                         defer putFloats(work)
685                 }
686                 return lapack64.Lantr(n, rm, work)
687         case RawSymmetricer:
688                 rm := rma.RawSymmetric()
689                 n := normLapack(norm, aTrans)
690                 if n == lapack.MaxRowSum || n == lapack.MaxColumnSum {
691                         work = getFloats(rm.N, false)
692                         defer putFloats(work)
693                 }
694                 return lapack64.Lansy(n, rm, work)
695         case *VecDense:
696                 rv := rma.RawVector()
697                 switch norm {
698                 default:
699                         panic("unreachable")
700                 case 1:
701                         if aTrans {
702                                 imax := blas64.Iamax(rma.n, rv)
703                                 return math.Abs(rma.At(imax, 0))
704                         }
705                         return blas64.Asum(rma.n, rv)
706                 case 2:
707                         return blas64.Nrm2(rma.n, rv)
708                 case math.Inf(1):
709                         if aTrans {
710                                 return blas64.Asum(rma.n, rv)
711                         }
712                         imax := blas64.Iamax(rma.n, rv)
713                         return math.Abs(rma.At(imax, 0))
714                 }
715         }
716         switch norm {
717         default:
718                 panic("unreachable")
719         case 1:
720                 var max float64
721                 for j := 0; j < c; j++ {
722                         var sum float64
723                         for i := 0; i < r; i++ {
724                                 sum += math.Abs(a.At(i, j))
725                         }
726                         if sum > max {
727                                 max = sum
728                         }
729                 }
730                 return max
731         case 2:
732                 var sum float64
733                 for i := 0; i < r; i++ {
734                         for j := 0; j < c; j++ {
735                                 v := a.At(i, j)
736                                 sum += v * v
737                         }
738                 }
739                 return math.Sqrt(sum)
740         case math.Inf(1):
741                 var max float64
742                 for i := 0; i < r; i++ {
743                         var sum float64
744                         for j := 0; j < c; j++ {
745                                 sum += math.Abs(a.At(i, j))
746                         }
747                         if sum > max {
748                                 max = sum
749                         }
750                 }
751                 return max
752         }
753 }
754
755 // normLapack converts the float64 norm input in Norm to a lapack.MatrixNorm.
756 func normLapack(norm float64, aTrans bool) lapack.MatrixNorm {
757         switch norm {
758         case 1:
759                 n := lapack.MaxColumnSum
760                 if aTrans {
761                         n = lapack.MaxRowSum
762                 }
763                 return n
764         case 2:
765                 return lapack.NormFrob
766         case math.Inf(1):
767                 n := lapack.MaxRowSum
768                 if aTrans {
769                         n = lapack.MaxColumnSum
770                 }
771                 return n
772         default:
773                 panic(ErrNormOrder)
774         }
775 }
776
777 // Sum returns the sum of the elements of the matrix.
778 func Sum(a Matrix) float64 {
779         // TODO(btracey): Add a fast path for the other supported matrix types.
780
781         r, c := a.Dims()
782         var sum float64
783         aU, _ := untranspose(a)
784         if rma, ok := aU.(RawMatrixer); ok {
785                 rm := rma.RawMatrix()
786                 for i := 0; i < rm.Rows; i++ {
787                         for _, v := range rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols] {
788                                 sum += v
789                         }
790                 }
791                 return sum
792         }
793         for i := 0; i < r; i++ {
794                 for j := 0; j < c; j++ {
795                         sum += a.At(i, j)
796                 }
797         }
798         return sum
799 }
800
801 // Trace returns the trace of the matrix. Trace will panic if the
802 // matrix is not square.
803 func Trace(a Matrix) float64 {
804         r, c := a.Dims()
805         if r != c {
806                 panic(ErrSquare)
807         }
808
809         aU, _ := untranspose(a)
810         switch m := aU.(type) {
811         case RawMatrixer:
812                 rm := m.RawMatrix()
813                 var t float64
814                 for i := 0; i < r; i++ {
815                         t += rm.Data[i*rm.Stride+i]
816                 }
817                 return t
818         case RawTriangular:
819                 rm := m.RawTriangular()
820                 var t float64
821                 for i := 0; i < r; i++ {
822                         t += rm.Data[i*rm.Stride+i]
823                 }
824                 return t
825         case RawSymmetricer:
826                 rm := m.RawSymmetric()
827                 var t float64
828                 for i := 0; i < r; i++ {
829                         t += rm.Data[i*rm.Stride+i]
830                 }
831                 return t
832         default:
833                 var t float64
834                 for i := 0; i < r; i++ {
835                         t += a.At(i, i)
836                 }
837                 return t
838         }
839 }
840
841 func min(a, b int) int {
842         if a < b {
843                 return a
844         }
845         return b
846 }
847
848 func max(a, b int) int {
849         if a > b {
850                 return a
851         }
852         return b
853 }
854
855 // use returns a float64 slice with l elements, using f if it
856 // has the necessary capacity, otherwise creating a new slice.
857 func use(f []float64, l int) []float64 {
858         if l <= cap(f) {
859                 return f[:l]
860         }
861         return make([]float64, l)
862 }
863
864 // useZeroed returns a float64 slice with l elements, using f if it
865 // has the necessary capacity, otherwise creating a new slice. The
866 // elements of the returned slice are guaranteed to be zero.
867 func useZeroed(f []float64, l int) []float64 {
868         if l <= cap(f) {
869                 f = f[:l]
870                 zero(f)
871                 return f
872         }
873         return make([]float64, l)
874 }
875
876 // zero zeros the given slice's elements.
877 func zero(f []float64) {
878         for i := range f {
879                 f[i] = 0
880         }
881 }
882
883 // useInt returns an int slice with l elements, using i if it
884 // has the necessary capacity, otherwise creating a new slice.
885 func useInt(i []int, l int) []int {
886         if l <= cap(i) {
887                 return i[:l]
888         }
889         return make([]int, l)
890 }