OSDN Git Service

Accelerate node dialing speed
[bytom/bytom.git] / p2p / secret_connection_test.go
1 // +build !network
2
3 package p2p
4
5 import (
6         "bytes"
7         "io"
8         "testing"
9
10         "github.com/tendermint/go-crypto"
11         cmn "github.com/tendermint/tmlibs/common"
12 )
13
// dummyConn is an in-memory, net.Conn-like endpoint: reads come from the
// embedded PipeReader and writes go to the embedded PipeWriter. In
// makeDummyConnPair the two halves come from different io.Pipe pairs, so a
// dummyConn talks to its peer rather than to itself.
type dummyConn struct {
	*io.PipeReader
	*io.PipeWriter
}

// Close shuts down both halves of the connection. The write half is closed
// with io.EOF so the peer's reads observe a clean end of stream, and the read
// half is closed so the peer's subsequent writes fail with io.ErrClosedPipe
// instead of blocking. The writer's close error takes precedence over the
// reader's.
func (drw dummyConn) Close() error {
	werr := drw.PipeWriter.CloseWithError(io.EOF)
	rerr := drw.PipeReader.Close()
	if werr != nil {
		return werr
	}
	return rerr
}
27
28 // Each returned ReadWriteCloser is akin to a net.Connection
29 func makeDummyConnPair() (fooConn, barConn dummyConn) {
30         barReader, fooWriter := io.Pipe()
31         fooReader, barWriter := io.Pipe()
32         return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter}
33 }
34
35 func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
36         fooConn, barConn := makeDummyConnPair()
37         fooPrvKey := crypto.GenPrivKeyEd25519()
38         fooPubKey := fooPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
39         barPrvKey := crypto.GenPrivKeyEd25519()
40         barPubKey := barPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
41
42         cmn.Parallel(
43                 func() {
44                         var err error
45                         fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
46                         if err != nil {
47                                 tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
48                                 return
49                         }
50                         remotePubBytes := fooSecConn.RemotePubKey()
51                         if !bytes.Equal(remotePubBytes[:], barPubKey[:]) {
52                                 tb.Errorf("Unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
53                                         barPubKey, fooSecConn.RemotePubKey())
54                         }
55                 },
56                 func() {
57                         var err error
58                         barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
59                         if barSecConn == nil {
60                                 tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
61                                 return
62                         }
63                         remotePubBytes := barSecConn.RemotePubKey()
64                         if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) {
65                                 tb.Errorf("Unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
66                                         fooPubKey, barSecConn.RemotePubKey())
67                         }
68                 })
69
70         return
71 }
72
73 func TestSecretConnectionHandshake(t *testing.T) {
74         fooSecConn, barSecConn := makeSecretConnPair(t)
75         fooSecConn.Close()
76         barSecConn.Close()
77 }
78
79 func TestSecretConnectionReadWrite(t *testing.T) {
80         fooConn, barConn := makeDummyConnPair()
81         fooWrites, barWrites := []string{}, []string{}
82         fooReads, barReads := []string{}, []string{}
83
84         // Pre-generate the things to write (for foo & bar)
85         for i := 0; i < 100; i++ {
86                 fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
87                 barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
88         }
89
90         // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
91         genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() {
92                 return func() {
93                         // Node handskae
94                         nodePrvKey := crypto.GenPrivKeyEd25519()
95                         nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
96                         if err != nil {
97                                 t.Errorf("Failed to establish SecretConnection for node: %v", err)
98                                 return
99                         }
100                         // In parallel, handle reads and writes
101                         cmn.Parallel(
102                                 func() {
103                                         // Node writes
104                                         for _, nodeWrite := range nodeWrites {
105                                                 n, err := nodeSecretConn.Write([]byte(nodeWrite))
106                                                 if err != nil {
107                                                         t.Errorf("Failed to write to nodeSecretConn: %v", err)
108                                                         return
109                                                 }
110                                                 if n != len(nodeWrite) {
111                                                         t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
112                                                         return
113                                                 }
114                                         }
115                                         nodeConn.PipeWriter.Close()
116                                 },
117                                 func() {
118                                         // Node reads
119                                         readBuffer := make([]byte, dataMaxSize)
120                                         for {
121                                                 n, err := nodeSecretConn.Read(readBuffer)
122                                                 if err == io.EOF {
123                                                         return
124                                                 } else if err != nil {
125                                                         t.Errorf("Failed to read from nodeSecretConn: %v", err)
126                                                         return
127                                                 }
128                                                 *nodeReads = append(*nodeReads, string(readBuffer[:n]))
129                                         }
130                                         nodeConn.PipeReader.Close()
131                                 })
132                 }
133         }
134
135         // Run foo & bar in parallel
136         cmn.Parallel(
137                 genNodeRunner(fooConn, fooWrites, &fooReads),
138                 genNodeRunner(barConn, barWrites, &barReads),
139         )
140
141         // A helper to ensure that the writes and reads match.
142         // Additionally, small writes (<= dataMaxSize) must be atomically read.
143         compareWritesReads := func(writes []string, reads []string) {
144                 for {
145                         // Pop next write & corresponding reads
146                         var read, write string = "", writes[0]
147                         var readCount = 0
148                         for _, readChunk := range reads {
149                                 read += readChunk
150                                 readCount += 1
151                                 if len(write) <= len(read) {
152                                         break
153                                 }
154                                 if len(write) <= dataMaxSize {
155                                         break // atomicity of small writes
156                                 }
157                         }
158                         // Compare
159                         if write != read {
160                                 t.Errorf("Expected to read %X, got %X", write, read)
161                         }
162                         // Iterate
163                         writes = writes[1:]
164                         reads = reads[readCount:]
165                         if len(writes) == 0 {
166                                 break
167                         }
168                 }
169         }
170
171         compareWritesReads(fooWrites, barReads)
172         compareWritesReads(barWrites, fooReads)
173
174 }
175
176 func BenchmarkSecretConnection(b *testing.B) {
177         b.StopTimer()
178         fooSecConn, barSecConn := makeSecretConnPair(b)
179         fooWriteText := cmn.RandStr(dataMaxSize)
180         // Consume reads from bar's reader
181         go func() {
182                 readBuffer := make([]byte, dataMaxSize)
183                 for {
184                         _, err := barSecConn.Read(readBuffer)
185                         if err == io.EOF {
186                                 return
187                         } else if err != nil {
188                                 b.Fatalf("Failed to read from barSecConn: %v", err)
189                         }
190                 }
191         }()
192
193         b.StartTimer()
194         for i := 0; i < b.N; i++ {
195                 _, err := fooSecConn.Write([]byte(fooWriteText))
196                 if err != nil {
197                         b.Fatalf("Failed to write to fooSecConn: %v", err)
198                 }
199         }
200         b.StopTimer()
201
202         fooSecConn.Close()
203         //barSecConn.Close() race condition
204 }