1 // Copyright (c) 2013-2016 The btcsuite developers
2 // Use of this source code is governed by an ISC
3 // license that can be found in the LICENSE file.
16 "github.com/btcsuite/btcd/chaincfg/chainhash"
17 "github.com/davecgh/go-spew/spew"
20 // makeHeader is a convenience function to make a message header in the form of
21 // a byte slice. It is used to force errors when reading messages.
22 func makeHeader(btcnet BitcoinNet, command string,
23 payloadLen uint32, checksum uint32) []byte {
25 // The length of a bitcoin message header is 24 bytes.
26 // 4 byte magic number of the bitcoin network + 12 byte command + 4 byte
27 // payload length + 4 byte checksum.
28 buf := make([]byte, 24)
29 binary.LittleEndian.PutUint32(buf, uint32(btcnet))
30 copy(buf[4:], []byte(command))
31 binary.LittleEndian.PutUint32(buf[16:], payloadLen)
32 binary.LittleEndian.PutUint32(buf[20:], checksum)
36 // TestMessage tests the Read/WriteMessage and Read/WriteMessageN API.
// Every table entry is exercised twice. The first pass uses WriteMessageN and
// ReadMessageN and also checks the reported byte counts. The second pass uses
// WriteMessage and ReadMessage, which do not report byte counts. In both
// passes the decoded message must reflect.DeepEqual the expected value.
37 func TestMessage(t *testing.T) {
38 pver := ProtocolVersion
40 // Create the various types of messages to test.
43 addrYou := &net.TCPAddr{IP: net.ParseIP("192.168.0.1"), Port: 8333}
44 you := NewNetAddress(addrYou, SFNodeNetwork)
45 you.Timestamp = time.Time{} // Version message has zero value timestamp.
46 addrMe := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8333}
47 me := NewNetAddress(addrMe, SFNodeNetwork)
48 me.Timestamp = time.Time{} // Version message has zero value timestamp.
49 msgVersion := NewMsgVersion(me, you, 123123, 0)
51 msgVerack := NewMsgVerAck()
52 msgGetAddr := NewMsgGetAddr()
53 msgAddr := NewMsgAddr()
54 msgGetBlocks := NewMsgGetBlocks(&chainhash.Hash{})
57 msgGetData := NewMsgGetData()
58 msgNotFound := NewMsgNotFound()
60 msgPing := NewMsgPing(123123)
61 msgPong := NewMsgPong(123123)
62 msgGetHeaders := NewMsgGetHeaders()
63 msgHeaders := NewMsgHeaders()
64 msgAlert := NewMsgAlert([]byte("payload"), []byte("signature"))
65 msgMemPool := NewMsgMemPool()
66 msgFilterAdd := NewMsgFilterAdd([]byte{0x01})
67 msgFilterClear := NewMsgFilterClear()
68 msgFilterLoad := NewMsgFilterLoad([]byte{0x01}, 10, 0, BloomUpdateNone)
69 bh := NewBlockHeader(1, &chainhash.Hash{}, &chainhash.Hash{}, 0, 0)
70 msgMerkleBlock := NewMsgMerkleBlock(bh)
71 msgReject := NewMsgReject("block", RejectDuplicate, "duplicate block")
// bytes is the full encoded size. It includes the 24-byte message header
// (see makeHeader), so entries expecting exactly 24 bytes (verack, getaddr,
// mempool, filterclear) presumably carry an empty payload.
74 in Message // Value to encode
75 out Message // Expected decoded value
76 pver uint32 // Protocol version for wire encoding
77 btcnet BitcoinNet // Network to use for wire encoding
78 bytes int // Expected num bytes read/written
80 {msgVersion, msgVersion, pver, MainNet, 125},
81 {msgVerack, msgVerack, pver, MainNet, 24},
82 {msgGetAddr, msgGetAddr, pver, MainNet, 24},
83 {msgAddr, msgAddr, pver, MainNet, 25},
84 {msgGetBlocks, msgGetBlocks, pver, MainNet, 61},
85 {msgBlock, msgBlock, pver, MainNet, 239},
86 {msgInv, msgInv, pver, MainNet, 25},
87 {msgGetData, msgGetData, pver, MainNet, 25},
88 {msgNotFound, msgNotFound, pver, MainNet, 25},
89 {msgTx, msgTx, pver, MainNet, 34},
90 {msgPing, msgPing, pver, MainNet, 32},
91 {msgPong, msgPong, pver, MainNet, 32},
92 {msgGetHeaders, msgGetHeaders, pver, MainNet, 61},
93 {msgHeaders, msgHeaders, pver, MainNet, 25},
94 {msgAlert, msgAlert, pver, MainNet, 42},
95 {msgMemPool, msgMemPool, pver, MainNet, 24},
96 {msgFilterAdd, msgFilterAdd, pver, MainNet, 26},
97 {msgFilterClear, msgFilterClear, pver, MainNet, 24},
98 {msgFilterLoad, msgFilterLoad, pver, MainNet, 35},
99 {msgMerkleBlock, msgMerkleBlock, pver, MainNet, 110},
100 {msgReject, msgReject, pver, MainNet, 79},
103 t.Logf("Running %d tests", len(tests))
104 for i, test := range tests {
105 // Encode to wire format.
107 nw, err := WriteMessageN(&buf, test.in, test.pver, test.btcnet)
109 t.Errorf("WriteMessage #%d error %v", i, err)
113 // Ensure the number of bytes written match the expected value.
114 if nw != test.bytes {
115 t.Errorf("WriteMessage #%d unexpected num bytes "+
116 "written - got %d, want %d", i, nw, test.bytes)
119 // Decode from wire format.
120 rbuf := bytes.NewReader(buf.Bytes())
// The third result of ReadMessageN is not checked by this test.
121 nr, msg, _, err := ReadMessageN(rbuf, test.pver, test.btcnet)
123 t.Errorf("ReadMessage #%d error %v, msg %v", i, err,
127 if !reflect.DeepEqual(msg, test.out) {
128 t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
129 spew.Sdump(msg), spew.Sdump(test.out))
133 // Ensure the number of bytes read match the expected value.
134 if nr != test.bytes {
135 t.Errorf("ReadMessage #%d unexpected num bytes read - "+
136 "got %d, want %d", i, nr, test.bytes)
140 // Do the same thing for Read/WriteMessage, but ignore the bytes since
141 // they don't return them.
142 t.Logf("Running %d tests", len(tests))
143 for i, test := range tests {
144 // Encode to wire format.
146 err := WriteMessage(&buf, test.in, test.pver, test.btcnet)
148 t.Errorf("WriteMessage #%d error %v", i, err)
152 // Decode from wire format.
153 rbuf := bytes.NewReader(buf.Bytes())
// The second result of ReadMessage is not checked by this test.
154 msg, _, err := ReadMessage(rbuf, test.pver, test.btcnet)
156 t.Errorf("ReadMessage #%d error %v, msg %v", i, err,
160 if !reflect.DeepEqual(msg, test.out) {
161 t.Errorf("ReadMessage #%d\n got: %v want: %v", i,
162 spew.Sdump(msg), spew.Sdump(test.out))
168 // TestReadMessageWireErrors performs negative tests against wire decoding into
169 // concrete messages to confirm error paths work correctly.
// It first sanity-checks MessageError.Error formatting. It then feeds
// hand-crafted wire bytes (mostly built with makeHeader) through ReadMessageN.
// For each input it checks the error type and the number of bytes read. For
// errors that are not *MessageError, it also checks the exact error value.
170 func TestReadMessageWireErrors(t *testing.T) {
171 pver := ProtocolVersion
174 // Ensure message errors are as expected with no function specified.
175 wantErr := "something bad happened"
176 testErr := MessageError{Description: wantErr}
177 if testErr.Error() != wantErr {
178 t.Errorf("MessageError: wrong error - got %v, want %v",
179 testErr.Error(), wantErr)
182 // Ensure message errors are as expected with a function specified.
184 testErr = MessageError{Func: wantFunc, Description: wantErr}
// NOTE(review): this compares against wantFunc+": "+wantErr, but the failure
// message below reports plain wantErr as the wanted value. It should probably
// print wantFunc+": "+wantErr so a failure is easier to diagnose.
185 if testErr.Error() != wantFunc+": "+wantErr {
186 t.Errorf("MessageError: wrong error - got %v, want %v",
187 testErr.Error(), wantErr)
190 // Wire encoded bytes for main and testnet3 networks magic identifiers.
191 testNet3Bytes := makeHeader(TestNet3, "", 0, 0)
193 // Wire encoded bytes for a message that exceeds max overall message
// payload length.
195 mpl := uint32(MaxMessagePayload)
196 exceedMaxPayloadBytes := makeHeader(btcnet, "getaddr", mpl+1, 0)
198 // Wire encoded bytes for a command which is invalid utf-8.
199 badCommandBytes := makeHeader(btcnet, "bogus", 0, 0)
200 badCommandBytes[4] = 0x81
202 // Wire encoded bytes for a command which is valid, but not supported.
203 unsupportedCommandBytes := makeHeader(btcnet, "bogus", 0, 0)
205 // Wire encoded bytes for a message which exceeds the max payload for
206 // a specific message type.
207 exceedTypePayloadBytes := makeHeader(btcnet, "getaddr", 1, 0)
209 // Wire encoded bytes for a message which does not deliver the full
210 // payload according to the header length.
211 shortPayloadBytes := makeHeader(btcnet, "version", 115, 0)
213 // Wire encoded bytes for a message with a bad checksum.
214 badChecksumBytes := makeHeader(btcnet, "version", 2, 0xbeef)
215 badChecksumBytes = append(badChecksumBytes, []byte{0x0, 0x0}...)
217 // Wire encoded bytes for a message which has a valid header, but is
218 // the wrong format. An addr starts with a varint of the number of
219 // addresses contained in the message. Claim there are two, but don't provide
220 // them. At the same time, forge the header fields so the message is
221 // otherwise accurate.
222 badMessageBytes := makeHeader(btcnet, "addr", 1, 0xeaadc31c)
223 badMessageBytes = append(badMessageBytes, 0x2)
225 // Wire encoded bytes for a message which the header claims has 15k
226 // bytes of data to discard.
227 discardBytes := makeHeader(btcnet, "bogus", 15*1024, 0)
230 buf []byte // Wire encoding
231 pver uint32 // Protocol version for wire encoding
232 btcnet BitcoinNet // Bitcoin network for wire encoding
233 max int // Max size of fixed buffer to induce errors
234 readErr error // Expected read error
235 bytes int // Expected num bytes read
237 // Latest protocol version with intentional read errors.
249 // Wrong network. Want MainNet, but giving TestNet3.
259 // Exceed max overall message payload length.
261 exceedMaxPayloadBytes,
264 len(exceedMaxPayloadBytes),
269 // Invalid UTF-8 command.
274 len(badCommandBytes),
279 // Valid, but unsupported command.
281 unsupportedCommandBytes,
284 len(unsupportedCommandBytes),
289 // Exceed max allowed payload for a message of a specific type.
291 exceedTypePayloadBytes,
294 len(exceedTypePayloadBytes),
299 // Message with a payload shorter than the header indicates.
304 len(shortPayloadBytes),
309 // Message with a bad checksum.
314 len(badChecksumBytes),
319 // Message with a valid header, but wrong format.
324 len(badMessageBytes),
329 // 15k bytes of data to discard.
// Each buffer is served through newFixedReader. Presumably it stops yielding
// data after test.max bytes to simulate short reads. TODO confirm against
// its definition.
340 t.Logf("Running %d tests", len(tests))
341 for i, test := range tests {
342 // Decode from wire format.
343 r := newFixedReader(test.max, test.buf)
344 nr, _, _, err := ReadMessageN(r, test.pver, test.btcnet)
345 if reflect.TypeOf(err) != reflect.TypeOf(test.readErr) {
346 t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
347 "want: %T", i, err, err, test.readErr)
351 // Ensure the number of bytes read matches the expected value.
352 if nr != test.bytes {
353 t.Errorf("ReadMessage #%d unexpected num bytes read - "+
354 "got %d, want %d", i, nr, test.bytes)
357 // For errors which are not of type MessageError, check them for
// equality against the expected error value.
359 if _, ok := err.(*MessageError); !ok {
360 if err != test.readErr {
361 t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
362 "want: %v <%T>", i, err, err,
363 test.readErr, test.readErr)
370 // TestWriteMessageWireErrors performs negative tests against wire encoding from
371 // concrete messages to confirm error paths work correctly.
372 func TestWriteMessageWireErrors(t *testing.T) {
373 pver := ProtocolVersion
375 wireErr := &MessageError{}
377 // Fake message with a command that is too long.
378 badCommandMsg := &fakeMessage{command: "somethingtoolong"}
380 // Fake message with a problem during encoding
381 encodeErrMsg := &fakeMessage{forceEncodeErr: true}
383 // Fake message that has payload which exceeds max overall message size.
384 exceedOverallPayload := make([]byte, MaxMessagePayload+1)
385 exceedOverallPayloadErrMsg := &fakeMessage{payload: exceedOverallPayload}
387 // Fake message that has payload which exceeds max allowed per message.
388 exceedPayload := make([]byte, 1)
389 exceedPayloadErrMsg := &fakeMessage{payload: exceedPayload, forceLenErr: true}
391 // Fake message that is used to force errors in the header and payload
393 bogusPayload := []byte{0x01, 0x02, 0x03, 0x04}
394 bogusMsg := &fakeMessage{command: "bogus", payload: bogusPayload}
397 msg Message // Message to encode
398 pver uint32 // Protocol version for wire encoding
399 btcnet BitcoinNet // Bitcoin network for wire encoding
400 max int // Max size of fixed buffer to induce errors
401 err error // Expected error
402 bytes int // Expected num bytes written
405 {badCommandMsg, pver, btcnet, 0, wireErr, 0},
406 // Force error in payload encode.
407 {encodeErrMsg, pver, btcnet, 0, wireErr, 0},
408 // Force error due to exceeding max overall message payload size.
409 {exceedOverallPayloadErrMsg, pver, btcnet, 0, wireErr, 0},
410 // Force error due to exceeding max payload for message type.
411 {exceedPayloadErrMsg, pver, btcnet, 0, wireErr, 0},
412 // Force error in header write.
413 {bogusMsg, pver, btcnet, 0, io.ErrShortWrite, 0},
414 // Force error in payload write.
415 {bogusMsg, pver, btcnet, 24, io.ErrShortWrite, 24},
418 t.Logf("Running %d tests", len(tests))
419 for i, test := range tests {
420 // Encode wire format.
421 w := newFixedWriter(test.max)
422 nw, err := WriteMessageN(w, test.msg, test.pver, test.btcnet)
423 if reflect.TypeOf(err) != reflect.TypeOf(test.err) {
424 t.Errorf("WriteMessage #%d wrong error got: %v <%T>, "+
425 "want: %T", i, err, err, test.err)
429 // Ensure the number of bytes written match the expected value.
430 if nw != test.bytes {
431 t.Errorf("WriteMessage #%d unexpected num bytes "+
432 "written - got %d, want %d", i, nw, test.bytes)
435 // For errors which are not of type MessageError, check them for
437 if _, ok := err.(*MessageError); !ok {
439 t.Errorf("ReadMessage #%d wrong error got: %v <%T>, "+
440 "want: %v <%T>", i, err, err,