1 // Copyright ©2015 The Gonum Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
10 "gonum.org/v1/gonum/blas"
11 "gonum.org/v1/gonum/blas/blas64"
14 // Dlatrs solves a triangular system of equations scaled to prevent overflow. It
16 // A * x = scale * b if trans == blas.NoTrans
17 // A^T * x = scale * b if trans == blas.Trans
18 // where the scale factor scale is chosen for numerical stability.
20 // A is an n×n triangular matrix. On entry, the slice x contains the values of
21 // b, and on exit it contains the solution vector x.
23 // If normin == true, cnorm is an input and cnorm[j] contains the norm of the off-diagonal
24 // part of the j^th column of A. If trans == blas.NoTrans, cnorm[j] must be greater
25 // than or equal to the infinity norm, and greater than or equal to the one-norm
26 // otherwise. If normin == false, then cnorm is treated as an output, and is set
27 // to contain the 1-norm of the off-diagonal part of the j^th column of A.
29 // Dlatrs is an internal routine. It is exported for testing purposes.
30 func (impl Implementation) Dlatrs(uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, normin bool, n int, a []float64, lda int, x []float64, cnorm []float64) (scale float64) {
31 if uplo != blas.Upper && uplo != blas.Lower {
34 if trans != blas.Trans && trans != blas.NoTrans {
37 if diag != blas.Unit && diag != blas.NonUnit {
40 upper := uplo == blas.Upper
41 noTrans := trans == blas.NoTrans
42 nonUnit := diag == blas.NonUnit
47 checkMatrix(n, n, a, lda)
49 checkVector(n, cnorm, 1)
54 smlnum := dlamchS / dlamchP
57 bi := blas64.Implementation()
61 for j := 1; j < n; j++ {
62 cnorm[j] = bi.Dasum(j, a[j:], lda)
65 for j := 0; j < n-1; j++ {
66 cnorm[j] = bi.Dasum(n-j-1, a[(j+1)*lda+j:], lda)
71 // Scale the column norms by tscal if the maximum element in cnorm is greater than bignum.
72 imax := bi.Idamax(n, cnorm, 1)
78 tscal = 1 / (smlnum * tmax)
79 bi.Dscal(n, tscal, cnorm, 1)
82 // Compute a bound on the computed solution vector to see if bi.Dtrsv can be used.
83 j := bi.Idamax(n, x, 1)
84 xmax := math.Abs(x[j])
87 var jfirst, jlast, jinc int
98 // Compute the growth in A * x = b.
104 grow = 1 / math.Max(xbnd, smlnum)
106 for j := jfirst; j != jlast; j += jinc {
110 tjj := math.Abs(a[j*lda+j])
111 xbnd = math.Min(xbnd, math.Min(1, tjj)*grow)
112 if tjj+cnorm[j] >= smlnum {
113 grow *= tjj / (tjj + cnorm[j])
120 grow = math.Min(1, 1/math.Max(xbnd, smlnum))
121 for j := jfirst; j != jlast; j += jinc {
125 grow *= 1 / (1 + cnorm[j])
143 grow = 1 / (math.Max(xbnd, smlnum))
145 for j := jfirst; j != jlast; j += jinc {
150 grow = math.Min(grow, xbnd/xj)
151 tjj := math.Abs(a[j*lda+j])
156 grow = math.Min(grow, xbnd)
158 grow = math.Min(1, 1/math.Max(xbnd, smlnum))
159 for j := jfirst; j != jlast; j += jinc {
170 if grow*tscal > smlnum {
171 // Use the Level 2 BLAS solve if the reciprocal of the bound on
172 // elements of X is not too small.
173 bi.Dtrsv(uplo, trans, diag, n, a, lda, x, 1)
175 bi.Dscal(n, 1/tscal, cnorm, 1)
180 // Use a Level 1 BLAS solve, scaling intermediate results.
182 scale = bignum / xmax
183 bi.Dscal(n, scale, x, 1)
187 for j := jfirst; j != jlast; j += jinc {
189 var tjj, tjjs float64
191 tjjs = a[j*lda+j] * tscal
203 bi.Dscal(n, rec, x, 1)
212 rec := (tjj * bignum) / xj
216 bi.Dscal(n, rec, x, 1)
223 for i := 0; i < n; i++ {
234 if cnorm[j] > (bignum-xmax)*rec {
236 bi.Dscal(n, rec, x, 1)
239 } else if xj*cnorm[j] > bignum-xmax {
240 bi.Dscal(n, 0.5, x, 1)
245 bi.Daxpy(j, -x[j]*tscal, a[j:], lda, x, 1)
246 i := bi.Idamax(j, x, 1)
247 xmax = math.Abs(x[i])
251 bi.Daxpy(n-j-1, -x[j]*tscal, a[(j+1)*lda+j:], lda, x[j+1:], 1)
252 i := j + bi.Idamax(n-j-1, x[j+1:], 1)
253 xmax = math.Abs(x[i])
258 for j := jfirst; j != jlast; j += jinc {
261 rec := 1 / math.Max(xmax, 1)
263 if cnorm[j] > (bignum-xj)*rec {
266 tjjs = a[j*lda+j] * tscal
270 tjj := math.Abs(tjjs)
272 rec = math.Min(1, rec*tjj)
276 bi.Dscal(n, rec, x, 1)
284 sumj = bi.Ddot(j, a[j:], lda, x, 1)
286 sumj = bi.Ddot(n-j-1, a[(j+1)*lda+j:], lda, x[j+1:], 1)
290 for i := 0; i < j; i++ {
291 sumj += (a[i*lda+j] * uscal) * x[i]
294 for i := j + 1; i < n; i++ {
295 sumj += (a[i*lda+j] * uscal) * x[i]
304 tjjs = a[j*lda+j] * tscal
311 tjj := math.Abs(tjjs)
316 bi.Dscal(n, rec, x, 1)
324 rec = (tjj * bignum) / xj
325 bi.Dscal(n, rec, x, 1)
331 for i := 0; i < n; i++ {
339 x[j] = x[j]/tjjs - sumj
342 xmax = math.Max(xmax, math.Abs(x[j]))
347 bi.Dscal(n, 1/tscal, cnorm, 1)