OSDN Git Service

new repo
[bytom/vapor.git] / p2p / connection / secret_connection_test.go
1 package connection
2
3 import (
4         "bytes"
5         "io"
6         "testing"
7
8         "github.com/tendermint/go-crypto"
9         cmn "github.com/tendermint/tmlibs/common"
10 )
11
// dummyConn is an in-memory ReadWriteCloser built from two io.Pipe halves:
// Read is served by the embedded PipeReader and Write by the embedded
// PipeWriter. Both embedded types define Close, so dummyConn declares its own
// Close to resolve that ambiguity and close both halves.
type dummyConn struct {
	*io.PipeReader
	*io.PipeWriter
}

// Close closes both halves of the connection. The write half is closed with
// io.EOF so reads on the peer's side of that pipe end with io.EOF; the read
// half is closed so writes on the peer's side of that pipe fail. If both
// closes fail, the write-half error is returned.
func (drw dummyConn) Close() error {
	writeErr := drw.PipeWriter.CloseWithError(io.EOF)
	readErr := drw.PipeReader.Close()
	if writeErr != nil {
		return writeErr
	}
	return readErr
}

// makeDummyConnPair returns two dummyConns cross-wired by a pair of pipes:
// bytes written to fooConn are read from barConn and vice versa. Each
// returned ReadWriteCloser is akin to a net.Conn.
func makeDummyConnPair() (fooConn, barConn dummyConn) {
	// Pipe 1 carries foo -> bar, pipe 2 carries bar -> foo.
	barReader, fooWriter := io.Pipe()
	fooReader, barWriter := io.Pipe()
	return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter}
}
32
33 func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
34         fooConn, barConn := makeDummyConnPair()
35         fooPrvKey := crypto.GenPrivKeyEd25519()
36         fooPubKey := fooPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
37         barPrvKey := crypto.GenPrivKeyEd25519()
38         barPubKey := barPrvKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
39
40         cmn.Parallel(
41                 func() {
42                         var err error
43                         fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
44                         if err != nil {
45                                 tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
46                                 return
47                         }
48                         remotePubBytes := fooSecConn.RemotePubKey()
49                         if !bytes.Equal(remotePubBytes[:], barPubKey[:]) {
50                                 tb.Errorf("Unexpected fooSecConn.RemotePubKey.  Expected %v, got %v",
51                                         barPubKey, fooSecConn.RemotePubKey())
52                         }
53                 },
54                 func() {
55                         var err error
56                         barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
57                         if barSecConn == nil {
58                                 tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
59                                 return
60                         }
61                         remotePubBytes := barSecConn.RemotePubKey()
62                         if !bytes.Equal(remotePubBytes[:], fooPubKey[:]) {
63                                 tb.Errorf("Unexpected barSecConn.RemotePubKey.  Expected %v, got %v",
64                                         fooPubKey, barSecConn.RemotePubKey())
65                         }
66                 })
67
68         return
69 }
70
71 func TestSecretConnectionHandshake(t *testing.T) {
72         fooSecConn, barSecConn := makeSecretConnPair(t)
73         fooSecConn.Close()
74         barSecConn.Close()
75 }
76
77 func TestSecretConnectionReadWrite(t *testing.T) {
78         fooConn, barConn := makeDummyConnPair()
79         fooWrites, barWrites := []string{}, []string{}
80         fooReads, barReads := []string{}, []string{}
81
82         // Pre-generate the things to write (for foo & bar)
83         for i := 0; i < 100; i++ {
84                 fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
85                 barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
86         }
87
88         // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
89         genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() {
90                 return func() {
91                         // Node handskae
92                         nodePrvKey := crypto.GenPrivKeyEd25519()
93                         nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
94                         if err != nil {
95                                 t.Errorf("Failed to establish SecretConnection for node: %v", err)
96                                 return
97                         }
98                         // In parallel, handle reads and writes
99                         cmn.Parallel(
100                                 func() {
101                                         // Node writes
102                                         for _, nodeWrite := range nodeWrites {
103                                                 n, err := nodeSecretConn.Write([]byte(nodeWrite))
104                                                 if err != nil {
105                                                         t.Errorf("Failed to write to nodeSecretConn: %v", err)
106                                                         return
107                                                 }
108                                                 if n != len(nodeWrite) {
109                                                         t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
110                                                         return
111                                                 }
112                                         }
113                                         nodeConn.PipeWriter.Close()
114                                 },
115                                 func() {
116                                         // Node reads
117                                         readBuffer := make([]byte, dataMaxSize)
118                                         for {
119                                                 n, err := nodeSecretConn.Read(readBuffer)
120                                                 if err == io.EOF {
121                                                         return
122                                                 } else if err != nil {
123                                                         t.Errorf("Failed to read from nodeSecretConn: %v", err)
124                                                         return
125                                                 }
126                                                 *nodeReads = append(*nodeReads, string(readBuffer[:n]))
127                                         }
128                                         nodeConn.PipeReader.Close()
129                                 })
130                 }
131         }
132
133         // Run foo & bar in parallel
134         cmn.Parallel(
135                 genNodeRunner(fooConn, fooWrites, &fooReads),
136                 genNodeRunner(barConn, barWrites, &barReads),
137         )
138
139         // A helper to ensure that the writes and reads match.
140         // Additionally, small writes (<= dataMaxSize) must be atomically read.
141         compareWritesReads := func(writes []string, reads []string) {
142                 for {
143                         // Pop next write & corresponding reads
144                         var read, write string = "", writes[0]
145                         var readCount = 0
146                         for _, readChunk := range reads {
147                                 read += readChunk
148                                 readCount++
149                                 if len(write) <= len(read) {
150                                         break
151                                 }
152                                 if len(write) <= dataMaxSize {
153                                         break // atomicity of small writes
154                                 }
155                         }
156                         // Compare
157                         if write != read {
158                                 t.Errorf("Expected to read %X, got %X", write, read)
159                         }
160                         // Iterate
161                         writes = writes[1:]
162                         reads = reads[readCount:]
163                         if len(writes) == 0 {
164                                 break
165                         }
166                 }
167         }
168
169         compareWritesReads(fooWrites, barReads)
170         compareWritesReads(barWrites, fooReads)
171
172 }
173
174 func BenchmarkSecretConnection(b *testing.B) {
175         b.StopTimer()
176         fooSecConn, barSecConn := makeSecretConnPair(b)
177         fooWriteText := cmn.RandStr(dataMaxSize)
178         // Consume reads from bar's reader
179         go func() {
180                 readBuffer := make([]byte, dataMaxSize)
181                 for {
182                         _, err := barSecConn.Read(readBuffer)
183                         if err == io.EOF {
184                                 return
185                         } else if err != nil {
186                                 b.Fatalf("Failed to read from barSecConn: %v", err)
187                         }
188                 }
189         }()
190
191         b.StartTimer()
192         for i := 0; i < b.N; i++ {
193                 _, err := fooSecConn.Write([]byte(fooWriteText))
194                 if err != nil {
195                         b.Fatalf("Failed to write to fooSecConn: %v", err)
196                 }
197         }
198         b.StopTimer()
199
200         fooSecConn.Close()
201         //barSecConn.Close() race condition
202 }