6 "gonum.org/v1/gonum/blas"
9 type Dgemmer interface {
10 Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int)
13 type DgemmCase struct {
22 var DgemmCases = []DgemmCase{
134 {110.5, -69, 73, 39},
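188 // transpose assumes a is a non-empty rectangular matrix: the result is sized from len(a[0]) rows of len(a) columns, so every row of a must have len(a[0]) elements.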
189 func transpose(a [][]float64) [][]float64 {
190 b := make([][]float64, len(a[0]))
192 b[i] = make([]float64, len(a))
193 for j := range b[i] {
200 func TestDgemm(t *testing.T, blasser Dgemmer) {
201 for i, test := range DgemmCases {
202 // Test that it passes row major
203 dgemmcomp(i, "RowMajorNoTrans", t, blasser, blas.NoTrans, blas.NoTrans,
204 test.m, test.n, test.k, test.alpha, test.beta, test.a, test.b, test.c, test.ans)
205 // Try with A transposed
206 dgemmcomp(i, "RowMajorTransA", t, blasser, blas.Trans, blas.NoTrans,
207 test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), test.b, test.c, test.ans)
208 // Try with B transposed
209 dgemmcomp(i, "RowMajorTransB", t, blasser, blas.NoTrans, blas.Trans,
210 test.m, test.n, test.k, test.alpha, test.beta, test.a, transpose(test.b), test.c, test.ans)
211 // Try with both transposed
212 dgemmcomp(i, "RowMajorTransBoth", t, blasser, blas.Trans, blas.Trans,
213 test.m, test.n, test.k, test.alpha, test.beta, transpose(test.a), transpose(test.b), test.c, test.ans)
217 func dgemmcomp(i int, name string, t *testing.T, blasser Dgemmer, tA, tB blas.Transpose, m, n, k int,
218 alpha, beta float64, a [][]float64, b [][]float64, c [][]float64, ans [][]float64) {
225 ansFlat := flatten(ans)
230 // Compute the matrix multiplication
231 blasser.Dgemm(tA, tB, m, n, k, alpha, aFlat, lda, bFlat, ldb, beta, cFlat, ldc)
233 if !dSliceEqual(aFlat, aCopy) {
234 t.Errorf("Test %v case %v: a changed during call to Dgemm", i, name)
236 if !dSliceEqual(bFlat, bCopy) {
237 t.Errorf("Test %v case %v: b changed during call to Dgemm", i, name)
240 if !dSliceTolEqual(ansFlat, cFlat) {
241 t.Errorf("Test %v case %v: answer mismatch. Expected %v, Found %v", i, name, ansFlat, cFlat)
243 // TODO: Need to add a sub-slice test where don't use up full matrix