1 // Copyright ©2013 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "gonum.org/v1/gonum/blas"
11 "gonum.org/v1/gonum/blas/blas64"
12 "gonum.org/v1/gonum/lapack/lapack64"
15 // Add adds a and b element-wise, placing the result in the receiver. Add
16 // will panic if the two matrices do not have the same shape.
17 func (m *Dense) Add(a, b Matrix) {
// The operands must agree in both dimensions.
20 if ar != br || ac != bc {
// Only the first result of untranspose is used here; the second is discarded.
24 aU, _ := untranspose(a)
25 bU, _ := untranspose(b)
// Raw-storage path: taken only when both a and b, as passed (not aU/bU),
// implement RawMatrixer.
28 if arm, ok := a.(RawMatrixer); ok {
29 if brm, ok := b.(RawMatrixer); ok {
30 amat, bmat := arm.RawMatrix(), brm.RawMatrix()
// ja, jb and jm are the starting offsets of the current run of ac elements
// in a, b and the receiver. Each advances by its own Stride, so the three
// backing slices need not share a stride.
37 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
38 for i, v := range amat.Data[ja : ja+ac] {
39 m.mat.Data[i+jm] = v + bmat.Data[i+jb]
// NOTE(review): the conditions guarding these two calls are not visible.
// Presumably the receiver is swapped for a workspace when it shares storage
// with aU or bU, and restore copies the result back — confirm.
48 m, restore = m.isolatedWorkspace(aU)
51 m, restore = m.isolatedWorkspace(bU)
// Generic path: read both operands through At and store each sum with set.
55 for r := 0; r < ar; r++ {
56 for c := 0; c < ac; c++ {
57 m.set(r, c, a.At(r, c)+b.At(r, c))
62 // Sub subtracts the matrix b from a, placing the result in the receiver. Sub
63 // will panic if the two matrices do not have the same shape.
64 func (m *Dense) Sub(a, b Matrix) {
// Both dimensions of a and b have to match.
67 if ar != br || ac != bc {
// The second result of untranspose is discarded for both operands.
71 aU, _ := untranspose(a)
72 bU, _ := untranspose(b)
// Direct-storage path, entered only if a and b themselves (rather than
// aU/bU) both satisfy RawMatrixer.
75 if arm, ok := a.(RawMatrixer); ok {
76 if brm, ok := b.(RawMatrixer); ok {
77 amat, bmat := arm.RawMatrix(), brm.RawMatrix()
// Step through a, b and the receiver in lock-step, one run of ac elements
// per iteration; the offsets ja, jb, jm each grow by their own Stride.
84 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
85 for i, v := range amat.Data[ja : ja+ac] {
// a's element minus the matching element of b.
86 m.mat.Data[i+jm] = v - bmat.Data[i+jb]
// NOTE(review): guarding conditions not visible. Presumably these replace the
// receiver with a workspace when it aliases aU or bU — confirm.
95 m, restore = m.isolatedWorkspace(aU)
98 m, restore = m.isolatedWorkspace(bU)
// Fallback through the Matrix interface: one At call per operand per element.
102 for r := 0; r < ar; r++ {
103 for c := 0; c < ac; c++ {
104 m.set(r, c, a.At(r, c)-b.At(r, c))
109 // MulElem performs element-wise multiplication of a and b, placing the result
110 // in the receiver. MulElem will panic if the two matrices do not have the same shape.
112 func (m *Dense) MulElem(a, b Matrix) {
// Reject operands whose dimensions differ in either direction.
115 if ar != br || ac != bc {
// untranspose's second result is unused in this method.
119 aU, _ := untranspose(a)
120 bU, _ := untranspose(b)
// Path over raw backing slices; requires a and b, as passed, to both be
// RawMatrixer values.
123 if arm, ok := a.(RawMatrixer); ok {
124 if brm, ok := b.(RawMatrixer); ok {
125 amat, bmat := arm.RawMatrix(), brm.RawMatrix()
// One iteration per stride-separated run: ja, jb and jm locate the start of
// the run in a, b and the receiver, each stepped by that matrix's Stride.
132 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
133 for i, v := range amat.Data[ja : ja+ac] {
// Element-wise (Hadamard) product, not a matrix product.
134 m.mat.Data[i+jm] = v * bmat.Data[i+jb]
// NOTE(review): the branch conditions around these calls are not visible.
// Presumably they isolate the receiver from an aliased operand — confirm.
143 m, restore = m.isolatedWorkspace(aU)
146 m, restore = m.isolatedWorkspace(bU)
// Interface-only path: multiply the values returned by At and store via set.
150 for r := 0; r < ar; r++ {
151 for c := 0; c < ac; c++ {
152 m.set(r, c, a.At(r, c)*b.At(r, c))
157 // DivElem performs element-wise division of a by b, placing the result
158 // in the receiver. DivElem will panic if the two matrices do not have the same shape.
160 func (m *Dense) DivElem(a, b Matrix) {
// a and b are required to have identical dimensions.
163 if ar != br || ac != bc {
// Only the unwrapped operands are kept; untranspose's second result is dropped.
167 aU, _ := untranspose(a)
168 bU, _ := untranspose(b)
// Backing-slice path, used when a and b (not aU/bU) both implement
// RawMatrixer.
171 if arm, ok := a.(RawMatrixer); ok {
172 if brm, ok := b.(RawMatrixer); ok {
173 amat, bmat := arm.RawMatrix(), brm.RawMatrix()
// Offsets ja, jb, jm mark the current run of ac elements in a, b and the
// receiver, and are advanced independently by the respective Stride.
180 for ja, jb, jm := 0, 0, 0; ja < ar*amat.Stride; ja, jb, jm = ja+amat.Stride, jb+bmat.Stride, jm+m.mat.Stride {
181 for i, v := range amat.Data[ja : ja+ac] {
// No check is made on the divisor here; the quotient is stored as computed.
182 m.mat.Data[i+jm] = v / bmat.Data[i+jb]
// NOTE(review): guards for these calls are not visible. Presumably a
// workspace stands in for the receiver when it overlaps aU or bU — confirm.
191 m, restore = m.isolatedWorkspace(aU)
194 m, restore = m.isolatedWorkspace(bU)
// Generic path using At on both operands and set on the receiver.
198 for r := 0; r < ar; r++ {
199 for c := 0; c < ac; c++ {
200 m.set(r, c, a.At(r, c)/b.At(r, c))
205 // Inverse computes the inverse of the matrix a, storing the result into the
206 // receiver. If a is ill-conditioned, a Condition error will be returned.
207 // Note that matrix inversion is numerically unstable, and should generally
208 // be avoided where possible, for example by using the Solve routines.
209 func (m *Dense) Inverse(a Matrix) error {
210 // TODO(btracey): Special case for RawTriangular, etc.
// NOTE(review): aTrans presumably reports whether a was a transposed view of
// aU — confirm against untranspose.
216 aU, aTrans := untranspose(a)
// Dispatch on the concrete type of the unwrapped operand.
217 switch rm := aU.(type) {
// Nothing needs to be moved when the receiver is aU itself and no transpose
// is pending.
219 if m != aU || aTrans {
// A temporary is obtained when the receiver is aU or overlaps its raw storage.
220 if m == aU || m.checkOverlap(rm.RawMatrix()) {
221 tmp := getWorkspace(r, c, false)
// Integer scratch of length r, passed to Getrf/Getri as the pivot indices.
232 ipiv := getInts(r, false)
// LU-factorize the receiver's storage in place.
234 ok := lapack64.Getrf(m.mat, ipiv)
// NOTE(review): the guard is not visible; presumably this is returned when
// Getrf reports failure (ok is false) — confirm.
236 return Condition(math.Inf(1))
238 work := getFloats(4*r, false) // must be at least 4*r for cond.
// Workspace query: with lwork == -1 Getri reports its preferred workspace
// length in work[0] instead of computing anything.
239 lapack64.Getri(m.mat, ipiv, work, -1)
// Replace the scratch slice only when Getri asks for more than 4*r floats.
240 if int(work[0]) > 4*r {
243 work = getFloats(l, false)
// Hand the float scratch back once Inverse returns.
247 defer putFloats(work)
// Form the inverse in place from the LU factors.
248 lapack64.Getri(m.mat, ipiv, work, len(work))
// Norm of the receiver's current contents, then a reciprocal condition
// estimate from Gecon using that norm.
249 norm := lapack64.Lange(CondNorm, m.mat, work)
250 rcond := lapack64.Gecon(CondNorm, m.mat, norm, work, ipiv) // reuse ipiv
// NOTE(review): guard not visible; presumably taken when rcond is zero — confirm.
252 return Condition(math.Inf(1))
// The inverse has been computed at this point; the error only flags that the
// condition number exceeded ConditionTolerance.
255 if cond > ConditionTolerance {
256 return Condition(cond)
261 // Mul takes the matrix product of a and b, placing the result in the receiver.
262 // If the number of columns in a does not equal the number of rows in b, Mul will panic.
263 func (m *Dense) Mul(a, b Matrix) {
// NOTE(review): aTrans and bTrans presumably report whether a and b were
// transposed views of aU and bU — confirm against untranspose.
271 aU, aTrans := untranspose(a)
272 bU, bTrans := untranspose(b)
// NOTE(review): guards not visible. Presumably the receiver is replaced by a
// workspace when it shares storage with aU or bU, with restore copying the
// product back — confirm.
276 m, restore = m.isolatedWorkspace(aU)
279 m, restore = m.isolatedWorkspace(bU)
291 // Some of the cases do not have a transpose option, so create
293 // C = A^T * B = (B^T * A)^T
// First family of fast paths: aU exposes general raw storage, and the
// concrete kind of bU selects the BLAS routine.
295 if aUrm, ok := aU.(RawMatrixer); ok {
296 amat := aUrm.RawMatrix()
// General times general.
300 if bUrm, ok := bU.(RawMatrixer); ok {
301 bmat := bUrm.RawMatrix()
// alpha is 1 and beta is 0, so the receiver is overwritten with the product.
305 blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
// General times symmetric.
308 if bU, ok := bU.(RawSymmetricer); ok {
309 bmat := bU.RawSymmetric()
// NOTE(review): presumably the transposed-a case: B*A is computed into an
// ac×ar workspace and later transposed into the receiver. The branch
// condition and copy-back are not visible — confirm.
311 c := getWorkspace(ac, ar, false)
312 blas64.Symm(blas.Left, 1, bmat, amat, 0, c.mat)
// Otherwise the symmetric operand is applied from the right, straight into
// the receiver.
317 blas64.Symm(blas.Right, 1, bmat, amat, 0, m.mat)
// General times triangular.
320 if bU, ok := bU.(RawTriangular); ok {
321 // Trmm updates in place, so copy aU first.
322 bmat := bU.RawTriangular()
// Workspace with the dimensions swapped (ac×ar), used with the left-side
// Trmm below.
324 c := getWorkspace(ac, ar, false)
326 tmp.SetRawMatrix(amat)
332 blas64.Trmm(blas.Left, bT, 1, bmat, c.mat)
// Right-side variant operating directly on the receiver's storage.
338 blas64.Trmm(blas.Right, bT, 1, bmat, m.mat)
// General times dense vector.
341 if bU, ok := bU.(*VecDense); ok {
// The result of checkOverlap is discarded here.
// NOTE(review): presumably it panics on a disallowed overlap — confirm.
342 m.checkOverlap(bU.asGeneral())
343 bvec := bU.RawVector()
345 // {ar,1} x {1,bc}, which is not a vector.
346 // Instead, construct B as a General.
347 bmat := blas64.General{
353 blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
// Matrix-vector case: the receiver's storage is described as a
// blas64.Vector and filled by Gemv.
356 cvec := blas64.Vector{
360 blas64.Gemv(aT, 1, amat, bvec, 0, cvec)
// Second family: bU exposes general raw storage and aU is one of the
// specialised kinds, mirroring the cases above with the sides exchanged.
364 if bUrm, ok := bU.(RawMatrixer); ok {
365 bmat := bUrm.RawMatrix()
// Symmetric times general.
369 if aU, ok := aU.(RawSymmetricer); ok {
370 amat := aU.RawSymmetric()
// NOTE(review): presumably the transposed-b case, computed into a bc×br
// workspace — branch condition not visible, confirm.
372 c := getWorkspace(bc, br, false)
373 blas64.Symm(blas.Right, 1, amat, bmat, 0, c.mat)
378 blas64.Symm(blas.Left, 1, amat, bmat, 0, m.mat)
// Triangular times general.
381 if aU, ok := aU.(RawTriangular); ok {
382 // Trmm updates in place, so copy bU first.
383 amat := aU.RawTriangular()
385 c := getWorkspace(bc, br, false)
387 tmp.SetRawMatrix(bmat)
393 blas64.Trmm(blas.Right, aT, 1, amat, c.mat)
399 blas64.Trmm(blas.Left, aT, 1, amat, m.mat)
// Dense vector times general.
402 if aU, ok := aU.(*VecDense); ok {
403 m.checkOverlap(aU.asGeneral())
404 avec := aU.RawVector()
407 // Transpose B so that the vector is on the right.
408 cvec := blas64.Vector{
416 blas64.Gemv(bT, 1, bmat, avec, 0, cvec)
419 // {ar,1} x {1,bc} which is not a vector result.
420 // Instead, construct A as a General.
421 amat := blas64.General{
427 blas64.Gemm(aT, bT, 1, amat, bmat, 0, m.mat)
// Fallback when no raw-storage path applied: row is a scratch slice of
// length ac.
// NOTE(review): presumably it holds one row of a while v accumulates the
// dot product with a column of b — the loop bodies are not visible, confirm.
432 row := getFloats(ac, false)
434 for r := 0; r < ar; r++ {
438 for c := 0; c < bc; c++ {
440 for i, e := range row {
// Store the computed value at row r, column c of the receiver's storage.
443 m.mat.Data[r*m.mat.Stride+c] = v
448 // strictCopy copies a into m panicking if the shape of a and m differ.
449 func strictCopy(m *Dense, a Matrix) {
// r and c are compared against the dimensions already recorded in m, so m is
// never resized to fit a.
451 if r != m.mat.Rows || c != m.mat.Cols {
452 // Panic with a string since this
453 // is not a user-facing panic.
// The panic value is the message text of ErrShape, not ErrShape itself.
454 panic(ErrShape.Error())
458 // Exp calculates the exponential of the matrix a, e^a, placing the result
459 // in the receiver. Exp will panic with ErrShape if a is not square.
461 // Exp uses the scaling and squaring method described in section 3 of
462 // http://www.cs.cornell.edu/cv/researchpdf/19ways+.pdf.
463 func (m *Dense) Exp(a Matrix) {
// The receiver is shaped r×r through reuseAsZeroed.
471 m.reuseAsZeroed(r, r)
// w is the only workspace here requested with a true third argument.
// NOTE(review): presumably that asks for cleared storage — confirm.
474 w = getWorkspace(r, r, true)
// Stepping by r+1 through r*r contiguous elements visits exactly the
// diagonal positions of an r×r matrix.
// NOTE(review): presumably they are set to 1 to form the identity — loop
// body not visible, confirm.
476 for i := 0; i < r*r; i += r + 1 {
// small = a * 2^-scaling.
485 small := getWorkspace(r, r, false)
486 small.Scale(math.Pow(2, -scaling), a)
487 power := getWorkspace(r, r, false)
491 tmp = getWorkspace(r, r, false)
// One pass per series term; i is a float64 loop counter.
494 for i := 1.; i < terms; i++ {
497 // This is OK to do because power and tmp are
498 // new Dense values so all rows are contiguous.
499 // TODO(kortschak) Make this explicit in the NewDense doc comment.
// tmp = power / factI, element by element over the flat backing slices.
500 for j, v := range power.mat.Data {
501 tmp.mat.Data[j] = v / factI
// Advance to the next power of small: the product is written into tmp and the
// two workspaces are then swapped so power holds it.
506 tmp.Mul(power, small)
507 tmp, power = power, tmp
// NOTE(review): presumably the squaring phase, undoing the 2^-scaling
// applied to small — loop body not visible, confirm.
512 for i := 0; i < scaling; i++ {
524 // Pow calculates the integral power of the matrix a to n, placing the result
525 // in the receiver. Pow will panic if n is negative or if a is not square.
526 func (m *Dense) Pow(a Matrix, n int) {
// NOTE(review): guard not visible; per the doc comment this is presumably the
// negative-n case — confirm.
528 panic("matrix: illegal power")
537 // Take possible fast paths.
// Writes the identity into the receiver: the first c entries of each row are
// passed to zero, then the diagonal entry of that row is set to 1.
// NOTE(review): presumably the n == 0 fast path — guard not visible.
540 for i := 0; i < r; i++ {
541 zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
542 m.mat.Data[i*m.mat.Stride+i] = 1
553 // Perform iterative exponentiation by squaring in work space.
// Three r×r workspaces are taken for the iteration.
554 w := getWorkspace(r, r, false)
556 s := getWorkspace(r, r, false)
558 x := getWorkspace(r, r, false)
// n is decremented once up front and then halved on every pass, so the loop
// consumes the bits of n-1 from least significant upward.
559 for n--; n > 0; n >>= 1 {
575 // Scale multiplies the elements of a by f, placing the result in the receiver.
577 // See the Scaler interface for more information.
578 func (m *Dense) Scale(f float64, a Matrix) {
// NOTE(review): aTrans presumably reports whether a was a transposed view of
// aU — confirm against untranspose.
583 aU, aTrans := untranspose(a)
// Raw-storage path for operands that expose their backing slice.
584 if rm, ok := aU.(RawMatrixer); ok {
585 amat := rm.RawMatrix()
// Taken when the receiver is the operand itself or checkOverlap reports an
// overlap with its storage.
586 if m == aU || m.checkOverlap(amat) {
588 m, restore = m.isolatedWorkspace(a)
// Same-layout copy: each run of ac elements of a is scaled into the run
// starting at the matching offset jm of the receiver.
592 for ja, jm := 0, 0; ja < ar*amat.Stride; ja, jm = ja+amat.Stride, jm+m.mat.Stride {
593 for i, v := range amat.Data[ja : ja+ac] {
594 m.mat.Data[i+jm] = v * f
// Transposing copy: each run of ar elements of a is written down one column
// of the receiver (jm advances by 1 per run, i selects the receiver row).
598 for ja, jm := 0, 0; ja < ac*amat.Stride; ja, jm = ja+amat.Stride, jm+1 {
599 for i, v := range amat.Data[ja : ja+ar] {
600 m.mat.Data[i*m.mat.Stride+jm] = v * f
// Generic path through At and set.
607 for r := 0; r < ar; r++ {
608 for c := 0; c < ac; c++ {
609 m.set(r, c, f*a.At(r, c))
614 // Apply applies the function fn to each of the elements of a, placing the
615 // resulting matrix in the receiver. The function fn takes a row/column
616 // index and element value and returns some function of that tuple.
617 func (m *Dense) Apply(fn func(i, j int, v float64) float64, a Matrix) {
// NOTE(review): aTrans presumably reports whether a was a transposed view of
// aU — confirm against untranspose.
622 aU, aTrans := untranspose(a)
// Raw-storage path for operands that expose their backing slice.
623 if rm, ok := aU.(RawMatrixer); ok {
624 amat := rm.RawMatrix()
// Taken when the receiver is the operand itself or checkOverlap reports an
// overlap with its storage.
625 if m == aU || m.checkOverlap(amat) {
627 m, restore = m.isolatedWorkspace(a)
// Same-layout traversal: j counts the runs and i the position inside a run,
// so fn is called as fn(j, i, v).
631 for j, ja, jm := 0, 0, 0; ja < ar*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+m.mat.Stride {
632 for i, v := range amat.Data[ja : ja+ac] {
633 m.mat.Data[i+jm] = fn(j, i, v)
// Transposing traversal: the value at position i of run j is stored at
// receiver row i, column j, and fn is called with those receiver indices.
637 for j, ja, jm := 0, 0, 0; ja < ac*amat.Stride; j, ja, jm = j+1, ja+amat.Stride, jm+1 {
638 for i, v := range amat.Data[ja : ja+ar] {
639 m.mat.Data[i*m.mat.Stride+jm] = fn(i, j, v)
// Generic path: fn sees the same (r, c) pair that is used for At and set.
646 for r := 0; r < ar; r++ {
647 for c := 0; c < ac; c++ {
648 m.set(r, c, fn(r, c, a.At(r, c)))
653 // RankOne performs a rank-one update to the matrix a and stores the result
654 // in the receiver. If a is zero, see Outer.
655 // m = a + alpha * x * y'
656 func (m *Dense) RankOne(a Matrix, alpha float64, x, y Vector) {
// x must be a column with as many rows as a.
659 if xr != ar || xc != 1 {
// y must be a column with as many rows as a has columns.
663 if yr != ac || yc != 1 {
668 aU, _ := untranspose(a)
669 if rm, ok := aU.(RawMatrixer); ok {
// The result of checkOverlap is discarded here.
// NOTE(review): presumably it panics on a disallowed overlap — confirm.
670 m.checkOverlap(rm.RawMatrix())
// Raw views of x and y, filled in only if the unwrapped vectors implement
// RawVectorer.
674 var xmat, ymat blas64.Vector
676 xU, _ := untranspose(x)
677 if rv, ok := xU.(RawVectorer); ok {
678 xmat = rv.RawVector()
// Wrap the raw vector as a general matrix so it can be overlap-checked
// against the receiver.
679 m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
683 yU, _ := untranspose(y)
684 if rv, ok := yU.(RawVectorer); ok {
685 ymat = rv.RawVector()
686 m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
// Ger adds alpha * x * y' onto the receiver's existing contents.
// NOTE(review): presumably a has been copied into the receiver before this
// call — the copy is not visible, confirm.
696 blas64.Ger(alpha, xmat, ymat, m.mat)
// Generic path: evaluate the update formula element by element through the
// At/AtVec accessors.
701 for i := 0; i < ar; i++ {
702 for j := 0; j < ac; j++ {
703 m.set(i, j, a.At(i, j)+alpha*x.AtVec(i)*y.AtVec(j))
708 // Outer calculates the outer product of the column vectors x and y,
709 // and stores the result in the receiver.
710 // m = alpha * x * y'
711 // In order to update an existing matrix, see RankOne.
712 func (m *Dense) Outer(alpha float64, x, y Vector) {
725 // Copied from reuseAs with use replaced by useZeroed
726 // and a final zero of the matrix elements if we pass
728 // TODO(kortschak): Factor out into reuseZeroedAs if
729 // we find another case that needs it.
// Sanity check: the stored dimensions must not exceed the recorded capacities.
730 if m.mat.Rows > m.capRows || m.mat.Cols > m.capCols {
731 // Panic as a string, not a mat.Error.
732 panic("mat: caps not correctly set")
// Rebuild the receiver's backing General; its Data comes from useZeroed with
// length r*c.
735 m.mat = blas64.General{
739 Data: useZeroed(m.mat.Data, r*c),
// NOTE(review): the receiver's existing shape differs from r×c here.
// Presumably a shape panic follows — the body is not visible, confirm.
743 } else if r != m.mat.Rows || c != m.mat.Cols {
// Raw views of x and y, filled in only if the unwrapped vectors implement
// RawVectorer.
747 var xmat, ymat blas64.Vector
749 xU, _ := untranspose(x)
750 if rv, ok := xU.(RawVectorer); ok {
751 xmat = rv.RawVector()
// Wrap the raw vector as a general matrix so it can be overlap-checked
// against the receiver.
752 m.checkOverlap((&VecDense{mat: xmat, n: x.Len()}).asGeneral())
756 yU, _ := untranspose(y)
757 if rv, ok := yU.(RawVectorer); ok {
758 ymat = rv.RawVector()
759 m.checkOverlap((&VecDense{mat: ymat, n: y.Len()}).asGeneral())
// Ger accumulates into its destination, so the first c entries of every row
// are passed to zero before the call.
765 for i := 0; i < r; i++ {
766 zero(m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c])
768 blas64.Ger(alpha, xmat, ymat, m.mat)
// Generic path: compute each product directly through AtVec and store with set.
772 for i := 0; i < r; i++ {
773 for j := 0; j < c; j++ {
774 m.set(i, j, alpha*x.AtVec(i)*y.AtVec(j))