OSDN Git Service

version 1.1.9 (#594)
[bytom/vapor.git] / p2p / connection / secret_connection_test.go
1 package connection
2
3 import (
4         "bytes"
5         "io"
6         "testing"
7
8         cmn "github.com/tendermint/tmlibs/common"
9
10         "github.com/bytom/vapor/crypto/ed25519/chainkd"
11 )
12
// dummyConn is an in-memory stand-in for one end of a network connection.
// Reads come from the embedded *io.PipeReader and writes go to the embedded
// *io.PipeWriter; makeDummyConnPair cross-wires two of them.
type dummyConn struct {
	*io.PipeReader
	*io.PipeWriter
}

// Close shuts down both halves of the connection.
//
// The write half is closed with io.EOF so that the peer's reads end with
// io.EOF. The read half is then closed so that the peer's writes fail instead
// of blocking. Both halves are always closed. The write-half error takes
// precedence if both fail.
func (drw dummyConn) Close() error {
	errW := drw.PipeWriter.CloseWithError(io.EOF)
	errR := drw.PipeReader.Close()
	if errW != nil {
		// Previously this returned the (always nil) named result, dropping errW.
		return errW
	}
	return errR
}
26
27 // Each returned ReadWriteCloser is akin to a net.Connection
28 func makeDummyConnPair() (fooConn, barConn dummyConn) {
29         barReader, fooWriter := io.Pipe()
30         fooReader, barWriter := io.Pipe()
31         return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter}
32 }
33
34 func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
35         fooConn, barConn := makeDummyConnPair()
36         fooPrvKey, _ := chainkd.NewXPrv(nil)
37         fooPubKey := fooPrvKey.XPub()
38         barPrvKey, _ := chainkd.NewXPrv(nil)
39         barPubKey := barPrvKey.XPub()
40
41         cmn.Parallel(
42                 func() {
43                         var err error
44                         fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
45                         if err != nil {
46                                 tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
47                                 return
48                         }
49                         remotePubBytes := fooSecConn.RemotePubKey()
50                         if !bytes.Equal(remotePubBytes.Bytes()[:], barPubKey[:]) {
51                                 tb.Errorf("Unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
52                                         barPubKey, fooSecConn.RemotePubKey())
53                         }
54                 },
55                 func() {
56                         var err error
57                         barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
58                         if barSecConn == nil {
59                                 tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
60                                 return
61                         }
62                         remotePubBytes := barSecConn.RemotePubKey()
63                         if !bytes.Equal(remotePubBytes.Bytes()[:], fooPubKey[:]) {
64                                 tb.Errorf("Unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
65                                         fooPubKey, barSecConn.RemotePubKey())
66                         }
67                 })
68
69         return
70 }
71
72 func TestSecretConnectionHandshake(t *testing.T) {
73         fooSecConn, barSecConn := makeSecretConnPair(t)
74         fooSecConn.Close()
75         barSecConn.Close()
76 }
77
78 func TestSecretConnectionReadWrite(t *testing.T) {
79         fooConn, barConn := makeDummyConnPair()
80         fooWrites, barWrites := []string{}, []string{}
81         fooReads, barReads := []string{}, []string{}
82
83         // Pre-generate the things to write (for foo & bar)
84         for i := 0; i < 100; i++ {
85                 fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
86                 barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
87         }
88
89         // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
90         genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() {
91                 return func() {
92                         // Node handskae
93                         nodePrvKey, _ := chainkd.NewXPrv(nil)
94                         nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
95                         if err != nil {
96                                 t.Errorf("Failed to establish SecretConnection for node: %v", err)
97                                 return
98                         }
99                         // In parallel, handle reads and writes
100                         cmn.Parallel(
101                                 func() {
102                                         // Node writes
103                                         for _, nodeWrite := range nodeWrites {
104                                                 n, err := nodeSecretConn.Write([]byte(nodeWrite))
105                                                 if err != nil {
106                                                         t.Errorf("Failed to write to nodeSecretConn: %v", err)
107                                                         return
108                                                 }
109                                                 if n != len(nodeWrite) {
110                                                         t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
111                                                         return
112                                                 }
113                                         }
114                                         nodeConn.PipeWriter.Close()
115                                 },
116                                 func() {
117                                         // Node reads
118                                         readBuffer := make([]byte, dataMaxSize)
119                                         for {
120                                                 n, err := nodeSecretConn.Read(readBuffer)
121                                                 if err == io.EOF {
122                                                         return
123                                                 } else if err != nil {
124                                                         t.Errorf("Failed to read from nodeSecretConn: %v", err)
125                                                         return
126                                                 }
127                                                 *nodeReads = append(*nodeReads, string(readBuffer[:n]))
128                                         }
129                                         nodeConn.PipeReader.Close()
130                                 })
131                 }
132         }
133
134         // Run foo & bar in parallel
135         cmn.Parallel(
136                 genNodeRunner(fooConn, fooWrites, &fooReads),
137                 genNodeRunner(barConn, barWrites, &barReads),
138         )
139
140         // A helper to ensure that the writes and reads match.
141         // Additionally, small writes (<= dataMaxSize) must be atomically read.
142         compareWritesReads := func(writes []string, reads []string) {
143                 for {
144                         // Pop next write & corresponding reads
145                         var read, write string = "", writes[0]
146                         var readCount = 0
147                         for _, readChunk := range reads {
148                                 read += readChunk
149                                 readCount++
150                                 if len(write) <= len(read) {
151                                         break
152                                 }
153                                 if len(write) <= dataMaxSize {
154                                         break // atomicity of small writes
155                                 }
156                         }
157                         // Compare
158                         if write != read {
159                                 t.Errorf("Expected to read %X, got %X", write, read)
160                         }
161                         // Iterate
162                         writes = writes[1:]
163                         reads = reads[readCount:]
164                         if len(writes) == 0 {
165                                 break
166                         }
167                 }
168         }
169
170         compareWritesReads(fooWrites, barReads)
171         compareWritesReads(barWrites, fooReads)
172
173 }
174
175 func BenchmarkSecretConnection(b *testing.B) {
176         b.StopTimer()
177         fooSecConn, barSecConn := makeSecretConnPair(b)
178         fooWriteText := cmn.RandStr(dataMaxSize)
179         // Consume reads from bar's reader
180         go func() {
181                 readBuffer := make([]byte, dataMaxSize)
182                 for {
183                         _, err := barSecConn.Read(readBuffer)
184                         if err == io.EOF {
185                                 return
186                         } else if err != nil {
187                                 b.Fatalf("Failed to read from barSecConn: %v", err)
188                         }
189                 }
190         }()
191
192         b.StartTimer()
193         for i := 0; i < b.N; i++ {
194                 _, err := fooSecConn.Write([]byte(fooWriteText))
195                 if err != nil {
196                         b.Fatalf("Failed to write to fooSecConn: %v", err)
197                 }
198         }
199         b.StopTimer()
200
201         fooSecConn.Close()
202         //barSecConn.Close() race condition
203 }