OSDN Git Service

new repo
[bytom/vapor.git] / vendor / gonum.org / v1 / gonum / mat / inner.go
1 // Copyright ©2014 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package mat
6
7 import (
8         "gonum.org/v1/gonum/blas"
9         "gonum.org/v1/gonum/blas/blas64"
10         "gonum.org/v1/gonum/internal/asm/f64"
11 )
12
13 // Inner computes the generalized inner product
14 //   x^T A y
15 // between column vectors x and y with matrix A. This is only a true inner product if
16 // A is symmetric positive definite, though the operation works for any matrix A.
17 //
18 // Inner panics if x.Len != m or y.Len != n when A is an m x n matrix.
19 func Inner(x Vector, a Matrix, y Vector) float64 {
20         m, n := a.Dims()
21         if x.Len() != m {
22                 panic(ErrShape)
23         }
24         if y.Len() != n {
25                 panic(ErrShape)
26         }
27         if m == 0 || n == 0 {
28                 return 0
29         }
30
31         var sum float64
32
33         switch a := a.(type) {
34         case RawSymmetricer:
35                 amat := a.RawSymmetric()
36                 if amat.Uplo != blas.Upper {
37                         // Panic as a string not a mat.Error.
38                         panic(badSymTriangle)
39                 }
40                 var xmat, ymat blas64.Vector
41                 if xrv, ok := x.(RawVectorer); ok {
42                         xmat = xrv.RawVector()
43                 } else {
44                         break
45                 }
46                 if yrv, ok := y.(RawVectorer); ok {
47                         ymat = yrv.RawVector()
48                 } else {
49                         break
50                 }
51                 for i := 0; i < x.Len(); i++ {
52                         xi := x.AtVec(i)
53                         if xi != 0 {
54                                 if ymat.Inc == 1 {
55                                         sum += xi * f64.DotUnitary(
56                                                 amat.Data[i*amat.Stride+i:i*amat.Stride+n],
57                                                 ymat.Data[i:],
58                                         )
59                                 } else {
60                                         sum += xi * f64.DotInc(
61                                                 amat.Data[i*amat.Stride+i:i*amat.Stride+n],
62                                                 ymat.Data[i*ymat.Inc:], uintptr(n-i),
63                                                 1, uintptr(ymat.Inc),
64                                                 0, 0,
65                                         )
66                                 }
67                         }
68                         yi := y.AtVec(i)
69                         if i != n-1 && yi != 0 {
70                                 if xmat.Inc == 1 {
71                                         sum += yi * f64.DotUnitary(
72                                                 amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
73                                                 xmat.Data[i+1:],
74                                         )
75                                 } else {
76                                         sum += yi * f64.DotInc(
77                                                 amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
78                                                 xmat.Data[(i+1)*xmat.Inc:], uintptr(n-i-1),
79                                                 1, uintptr(xmat.Inc),
80                                                 0, 0,
81                                         )
82                                 }
83                         }
84                 }
85                 return sum
86         case RawMatrixer:
87                 amat := a.RawMatrix()
88                 var ymat blas64.Vector
89                 if yrv, ok := y.(RawVectorer); ok {
90                         ymat = yrv.RawVector()
91                 } else {
92                         break
93                 }
94                 for i := 0; i < x.Len(); i++ {
95                         xi := x.AtVec(i)
96                         if xi != 0 {
97                                 if ymat.Inc == 1 {
98                                         sum += xi * f64.DotUnitary(
99                                                 amat.Data[i*amat.Stride:i*amat.Stride+n],
100                                                 ymat.Data,
101                                         )
102                                 } else {
103                                         sum += xi * f64.DotInc(
104                                                 amat.Data[i*amat.Stride:i*amat.Stride+n],
105                                                 ymat.Data, uintptr(n),
106                                                 1, uintptr(ymat.Inc),
107                                                 0, 0,
108                                         )
109                                 }
110                         }
111                 }
112                 return sum
113         }
114         for i := 0; i < x.Len(); i++ {
115                 xi := x.AtVec(i)
116                 for j := 0; j < y.Len(); j++ {
117                         sum += xi * a.At(i, j) * y.AtVec(j)
118                 }
119         }
120         return sum
121 }