OSDN Git Service

Merge pull request #41 from Bytom/dev
[bytom/vapor.git] / p2p / connection / connection_test.go
1 package connection
2
3 import (
4         "net"
5         "testing"
6         "time"
7
8         "github.com/stretchr/testify/assert"
9         "github.com/stretchr/testify/require"
10 )
11
12 func createMConnection(conn net.Conn) *MConnection {
13         onReceive := func(chID byte, msgBytes []byte) {
14         }
15         onError := func(r interface{}) {
16         }
17         c := createMConnectionWithCallbacks(conn, onReceive, onError)
18         return c
19 }
20
21 func createMConnectionWithCallbacks(conn net.Conn, onReceive func(chID byte, msgBytes []byte), onError func(r interface{})) *MConnection {
22         chDescs := []*ChannelDescriptor{&ChannelDescriptor{ID: 0x01, Priority: 1, SendQueueCapacity: 1}}
23         c := NewMConnectionWithConfig(conn, chDescs, onReceive, onError, DefaultMConnConfig())
24         return c
25 }
26
27 func TestMConnectionSend(t *testing.T) {
28         assert, require := assert.New(t), require.New(t)
29
30         server, client := net.Pipe()
31         defer server.Close()
32         defer client.Close()
33
34         mconn := createMConnection(client)
35         _, err := mconn.Start()
36         require.Nil(err)
37         defer mconn.Stop()
38
39         msg := "Ant-Man"
40         assert.True(mconn.Send(0x01, msg))
41         // Note: subsequent Send/TrySend calls could pass because we are reading from
42         // the send queue in a separate goroutine.
43         server.Read(make([]byte, len(msg)))
44         assert.True(mconn.CanSend(0x01))
45
46         msg = "Spider-Man"
47         assert.True(mconn.TrySend(0x01, msg))
48         server.Read(make([]byte, len(msg)))
49
50         assert.False(mconn.CanSend(0x05), "CanSend should return false because channel is unknown")
51         assert.False(mconn.Send(0x05, "Absorbing Man"), "Send should return false because channel is unknown")
52 }
53
54 func TestMConnectionReceive(t *testing.T) {
55         assert, require := assert.New(t), require.New(t)
56
57         server, client := net.Pipe()
58         defer server.Close()
59         defer client.Close()
60
61         receivedCh := make(chan []byte)
62         errorsCh := make(chan interface{})
63         onReceive := func(chID byte, msgBytes []byte) {
64                 receivedCh <- msgBytes
65         }
66         onError := func(r interface{}) {
67                 errorsCh <- r
68         }
69         mconn1 := createMConnectionWithCallbacks(client, onReceive, onError)
70         _, err := mconn1.Start()
71         require.Nil(err)
72         defer mconn1.Stop()
73
74         mconn2 := createMConnection(server)
75         _, err = mconn2.Start()
76         require.Nil(err)
77         defer mconn2.Stop()
78
79         msg := "Cyclops"
80         assert.True(mconn2.Send(0x01, msg))
81
82         select {
83         case receivedBytes := <-receivedCh:
84                 assert.Equal([]byte(msg), receivedBytes[2:]) // first 3 bytes are internal
85         case err := <-errorsCh:
86                 t.Fatalf("Expected %s, got %+v", msg, err)
87         case <-time.After(500 * time.Millisecond):
88                 t.Fatalf("Did not receive %s message in 500ms", msg)
89         }
90 }
91
92 func TestMConnectionStopsAndReturnsError(t *testing.T) {
93         assert, require := assert.New(t), require.New(t)
94
95         server, client := net.Pipe()
96         defer server.Close()
97         defer client.Close()
98
99         receivedCh := make(chan []byte)
100         errorsCh := make(chan interface{})
101         onReceive := func(chID byte, msgBytes []byte) {
102                 receivedCh <- msgBytes
103         }
104         onError := func(r interface{}) {
105                 errorsCh <- r
106         }
107         mconn := createMConnectionWithCallbacks(client, onReceive, onError)
108         _, err := mconn.Start()
109         require.Nil(err)
110         defer mconn.Stop()
111
112         client.Close()
113
114         select {
115         case receivedBytes := <-receivedCh:
116                 t.Fatalf("Expected error, got %v", receivedBytes)
117         case err := <-errorsCh:
118                 assert.NotNil(err)
119                 assert.False(mconn.IsRunning())
120         case <-time.After(500 * time.Millisecond):
121                 t.Fatal("Did not receive error in 500ms")
122         }
123 }