OSDN Git Service

f125dfd0e5d4836111528acabcae1a66aa2fcefe
[bytom/bytom.git] / p2p / switch_test.go
1 // +build !network
2
3 package p2p
4
5 import (
6         "bytes"
7         "fmt"
8         "net"
9         "sync"
10         "testing"
11         "time"
12
13         "github.com/stretchr/testify/assert"
14         "github.com/stretchr/testify/require"
15         crypto "github.com/tendermint/go-crypto"
16         wire "github.com/tendermint/go-wire"
17
18         cfg "github.com/bytom/config"
19         "github.com/tendermint/tmlibs/log"
20 )
21
22 var (
23         config *cfg.P2PConfig
24 )
25
26 func init() {
27         config = cfg.DefaultP2PConfig()
28         config.PexReactor = true
29 }
30
31 type PeerMessage struct {
32         PeerKey string
33         Bytes   []byte
34         Counter int
35 }
36
37 type TestReactor struct {
38         BaseReactor
39
40         mtx          sync.Mutex
41         channels     []*ChannelDescriptor
42         peersAdded   []*Peer
43         peersRemoved []*Peer
44         logMessages  bool
45         msgsCounter  int
46         msgsReceived map[byte][]PeerMessage
47 }
48
49 func NewTestReactor(channels []*ChannelDescriptor, logMessages bool) *TestReactor {
50         tr := &TestReactor{
51                 channels:     channels,
52                 logMessages:  logMessages,
53                 msgsReceived: make(map[byte][]PeerMessage),
54         }
55         tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
56         tr.SetLogger(log.TestingLogger())
57         return tr
58 }
59
60 func (tr *TestReactor) GetChannels() []*ChannelDescriptor {
61         return tr.channels
62 }
63
64 func (tr *TestReactor) AddPeer(peer *Peer) {
65         tr.mtx.Lock()
66         defer tr.mtx.Unlock()
67         tr.peersAdded = append(tr.peersAdded, peer)
68 }
69
70 func (tr *TestReactor) RemovePeer(peer *Peer, reason interface{}) {
71         tr.mtx.Lock()
72         defer tr.mtx.Unlock()
73         tr.peersRemoved = append(tr.peersRemoved, peer)
74 }
75
76 func (tr *TestReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {
77         if tr.logMessages {
78                 tr.mtx.Lock()
79                 defer tr.mtx.Unlock()
80                 //fmt.Printf("Received: %X, %X\n", chID, msgBytes)
81                 tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.Key, msgBytes, tr.msgsCounter})
82                 tr.msgsCounter++
83         }
84 }
85
86 func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
87         tr.mtx.Lock()
88         defer tr.mtx.Unlock()
89         return tr.msgsReceived[chID]
90 }
91
92 //-----------------------------------------------------------------------------
93
94 // convenience method for creating two switches connected to each other.
95 // XXX: note this uses net.Pipe and not a proper TCP conn
96 func makeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
97         // Create two switches that will be interconnected.
98         switches := MakeConnectedSwitches(config, 2, initSwitch, Connect2Switches)
99         return switches[0], switches[1]
100 }
101
102 func initSwitchFunc(i int, sw *Switch) *Switch {
103         // Make two reactors of two channels each
104         sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{
105                 &ChannelDescriptor{ID: byte(0x00), Priority: 10},
106                 &ChannelDescriptor{ID: byte(0x01), Priority: 10},
107         }, true))
108         sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{
109                 &ChannelDescriptor{ID: byte(0x02), Priority: 10},
110                 &ChannelDescriptor{ID: byte(0x03), Priority: 10},
111         }, true))
112         return sw
113 }
114
115 func TestSwitches(t *testing.T) {
116         s1, s2 := makeSwitchPair(t, initSwitchFunc)
117         defer s1.Stop()
118         defer s2.Stop()
119
120         if s1.Peers().Size() != 1 {
121                 t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size())
122         }
123         if s2.Peers().Size() != 1 {
124                 t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size())
125         }
126
127         // Lets send some messages
128         ch0Msg := "channel zero"
129         ch1Msg := "channel foo"
130         ch2Msg := "channel bar"
131
132         s1.Broadcast(byte(0x00), ch0Msg)
133         s1.Broadcast(byte(0x01), ch1Msg)
134         s1.Broadcast(byte(0x02), ch2Msg)
135
136         // Wait for things to settle...
137         time.Sleep(5000 * time.Millisecond)
138
139         // Check message on ch0
140         ch0Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x00))
141         if len(ch0Msgs) != 1 {
142                 t.Errorf("Expected to have received 1 message in ch0")
143         }
144         if !bytes.Equal(ch0Msgs[0].Bytes, wire.BinaryBytes(ch0Msg)) {
145                 t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch0Msg), ch0Msgs[0].Bytes)
146         }
147
148         // Check message on ch1
149         ch1Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x01))
150         if len(ch1Msgs) != 1 {
151                 t.Errorf("Expected to have received 1 message in ch1")
152         }
153         if !bytes.Equal(ch1Msgs[0].Bytes, wire.BinaryBytes(ch1Msg)) {
154                 t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch1Msg), ch1Msgs[0].Bytes)
155         }
156
157         // Check message on ch2
158         ch2Msgs := s2.Reactor("bar").(*TestReactor).getMsgs(byte(0x02))
159         if len(ch2Msgs) != 1 {
160                 t.Errorf("Expected to have received 1 message in ch2")
161         }
162         if !bytes.Equal(ch2Msgs[0].Bytes, wire.BinaryBytes(ch2Msg)) {
163                 t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch2Msg), ch2Msgs[0].Bytes)
164         }
165
166 }
167
168 func TestConnAddrFilter(t *testing.T) {
169         s1 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
170         s2 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
171
172         c1, c2 := net.Pipe()
173
174         s1.SetAddrFilter(func(addr net.Addr) error {
175                 if addr.String() == c1.RemoteAddr().String() {
176                         return fmt.Errorf("Error: pipe is blacklisted")
177                 }
178                 return nil
179         })
180
181         // connect to good peer
182         go func() {
183                 s1.addPeerWithConnection(c1)
184         }()
185         go func() {
186                 s2.addPeerWithConnection(c2)
187         }()
188
189         // Wait for things to happen, peers to get added...
190         time.Sleep(100 * time.Millisecond * time.Duration(4))
191
192         defer s1.Stop()
193         defer s2.Stop()
194         if s1.Peers().Size() != 0 {
195                 t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size())
196         }
197         if s2.Peers().Size() != 0 {
198                 t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size())
199         }
200 }
201
202 func TestConnPubKeyFilter(t *testing.T) {
203         s1 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
204         s2 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
205
206         c1, c2 := net.Pipe()
207
208         // set pubkey filter
209         s1.SetPubKeyFilter(func(pubkey crypto.PubKeyEd25519) error {
210                 if bytes.Equal(pubkey.Bytes(), s2.nodeInfo.PubKey.Bytes()) {
211                         return fmt.Errorf("Error: pipe is blacklisted")
212                 }
213                 return nil
214         })
215
216         // connect to good peer
217         go func() {
218                 s1.addPeerWithConnection(c1)
219         }()
220         go func() {
221                 s2.addPeerWithConnection(c2)
222         }()
223
224         // Wait for things to happen, peers to get added...
225         time.Sleep(100 * time.Millisecond * time.Duration(4))
226
227         defer s1.Stop()
228         defer s2.Stop()
229         if s1.Peers().Size() != 0 {
230                 t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size())
231         }
232         if s2.Peers().Size() != 0 {
233                 t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size())
234         }
235 }
236
237 func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
238         assert, require := assert.New(t), require.New(t)
239
240         sw := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
241         sw.Start()
242         defer sw.Stop()
243
244         // simulate remote peer
245         rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig(config)}
246         rp.Start()
247         defer rp.Stop()
248
249         peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config)
250         require.Nil(err)
251         err = sw.AddPeer(peer)
252         require.Nil(err)
253
254         // simulate failure by closing connection
255         peer.CloseConn()
256
257         time.Sleep(100 * time.Millisecond)
258
259         assert.Zero(sw.Peers().Size())
260         assert.False(peer.IsRunning())
261 }
262
263 func TestSwitchReconnectsToPersistentPeer(t *testing.T) {
264         assert, require := assert.New(t), require.New(t)
265
266         sw := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
267         sw.Start()
268         defer sw.Stop()
269
270         // simulate remote peer
271         rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig(config)}
272         rp.Start()
273         defer rp.Stop()
274
275         peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config)
276         peer.makePersistent()
277         require.Nil(err)
278         err = sw.AddPeer(peer)
279         require.Nil(err)
280
281         // simulate failure by closing connection
282         peer.CloseConn()
283
284         // TODO: actually detect the disconnection and wait for reconnect
285         time.Sleep(100 * time.Millisecond)
286
287         assert.NotZero(sw.Peers().Size())
288         assert.False(peer.IsRunning())
289 }
290
291 func BenchmarkSwitches(b *testing.B) {
292         b.StopTimer()
293
294         s1, s2 := makeSwitchPair(b, func(i int, sw *Switch) *Switch {
295                 // Make bar reactors of bar channels each
296                 sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{
297                         &ChannelDescriptor{ID: byte(0x00), Priority: 10},
298                         &ChannelDescriptor{ID: byte(0x01), Priority: 10},
299                 }, false))
300                 sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{
301                         &ChannelDescriptor{ID: byte(0x02), Priority: 10},
302                         &ChannelDescriptor{ID: byte(0x03), Priority: 10},
303                 }, false))
304                 return sw
305         })
306         defer s1.Stop()
307         defer s2.Stop()
308
309         // Allow time for goroutines to boot up
310         time.Sleep(1000 * time.Millisecond)
311         b.StartTimer()
312
313         numSuccess, numFailure := 0, 0
314
315         // Send random message from foo channel to another
316         for i := 0; i < b.N; i++ {
317                 chID := byte(i % 4)
318                 successChan := s1.Broadcast(chID, "test data")
319                 for s := range successChan {
320                         if s {
321                                 numSuccess++
322                         } else {
323                                 numFailure++
324                         }
325                 }
326         }
327
328         b.Logf("success: %v, failure: %v", numSuccess, numFailure)
329
330         // Allow everything to flush before stopping switches & closing connections.
331         b.StopTimer()
332         time.Sleep(1000 * time.Millisecond)
333 }