OSDN Git Service

Merge pull request #1072 from Bytom/prod
[bytom/bytom.git] / p2p / discover / udp.go
1 // Copyright 2016 The go-ethereum Authors
2 // This file is part of the go-ethereum library.
3 //
4 // The go-ethereum library is free software: you can redistribute it and/or modify
5 // it under the terms of the GNU Lesser General Public License as published by
6 // the Free Software Foundation, either version 3 of the License, or
7 // (at your option) any later version.
8 //
9 // The go-ethereum library is distributed in the hope that it will be useful,
10 // but WITHOUT ANY WARRANTY; without even the implied warranty of
11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 // GNU Lesser General Public License for more details.
13 //
14 // You should have received a copy of the GNU Lesser General Public License
15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
16
17 package discover
18
19 import (
20         "bytes"
21         "crypto/ecdsa"
22         "errors"
23         "fmt"
24         "net"
25         "time"
26
27         log "github.com/sirupsen/logrus"
28         "github.com/tendermint/go-crypto"
29         "github.com/tendermint/go-wire"
30
31         "github.com/bytom/common"
32         "github.com/bytom/p2p/netutil"
33 )
34
35 const Version = 4
36
37 // Errors
38 var (
39         errPacketTooSmall   = errors.New("too small")
40         errBadPrefix        = errors.New("bad prefix")
41         errExpired          = errors.New("expired")
42         errUnsolicitedReply = errors.New("unsolicited reply")
43         errUnknownNode      = errors.New("unknown node")
44         errTimeout          = errors.New("RPC timeout")
45         errClockWarp        = errors.New("reply deadline too far in the future")
46         errClosed           = errors.New("socket closed")
47 )
48
49 // Timeouts
50 const (
51         respTimeout = 500 * time.Millisecond
52         queryDelay  = 1000 * time.Millisecond
53         expiration  = 20 * time.Second
54
55         ntpFailureThreshold = 32               // Continuous timeouts after which to check NTP
56         ntpWarningCooldown  = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
57         driftThreshold      = 10 * time.Second // Allowed clock drift before warning user
58 )
59
60 // ReadPacket is sent to the unhandled channel when it could not be processed
61 type ReadPacket struct {
62         Data []byte
63         Addr *net.UDPAddr
64 }
65
66 // Config holds Table-related settings.
67 type Config struct {
68         // These settings are required and configure the UDP listener:
69         PrivateKey *ecdsa.PrivateKey
70
71         // These settings are optional:
72         AnnounceAddr *net.UDPAddr // local address announced in the DHT
73         NodeDBPath   string       // if set, the node database is stored at this filesystem location
74         //NetRestrict  *netutil.Netlist  // network whitelist
75         Bootnodes []*Node           // list of bootstrap nodes
76         Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
77 }
78
79 // RPC request structures
80 type (
81         ping struct {
82                 Version    uint
83                 From, To   rpcEndpoint
84                 Expiration uint64
85
86                 // v5
87                 Topics []Topic
88
89                 // Ignore additional fields (for forward compatibility).
90                 Rest []byte
91         }
92
93         // pong is the reply to ping.
94         pong struct {
95                 // This field should mirror the UDP envelope address
96                 // of the ping packet, which provides a way to discover the
97                 // the external address (after NAT).
98                 To rpcEndpoint
99
100                 ReplyTok   []byte // This contains the hash of the ping packet.
101                 Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
102
103                 // v5
104                 TopicHash    common.Hash
105                 TicketSerial uint32
106                 WaitPeriods  []uint32
107
108                 // Ignore additional fields (for forward compatibility).
109                 Rest []byte
110         }
111
112         // findnode is a query for nodes close to the given target.
113         findnode struct {
114                 Target     NodeID // doesn't need to be an actual public key
115                 Expiration uint64
116                 // Ignore additional fields (for forward compatibility).
117                 Rest []byte
118         }
119
120         // findnode is a query for nodes close to the given target.
121         findnodeHash struct {
122                 Target     common.Hash
123                 Expiration uint64
124                 // Ignore additional fields (for forward compatibility).
125                 Rest []byte
126         }
127
128         // reply to findnode
129         neighbors struct {
130                 Nodes      []rpcNode
131                 Expiration uint64
132                 // Ignore additional fields (for forward compatibility).
133                 Rest []byte
134         }
135
136         topicRegister struct {
137                 Topics []Topic
138                 Idx    uint
139                 Pong   []byte
140         }
141
142         topicQuery struct {
143                 Topic      Topic
144                 Expiration uint64
145         }
146
147         // reply to topicQuery
148         topicNodes struct {
149                 Echo  common.Hash
150                 Nodes []rpcNode
151         }
152
153         rpcNode struct {
154                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
155                 UDP uint16 // for discovery protocol
156                 TCP uint16 // for RLPx protocol
157                 ID  NodeID
158         }
159
160         rpcEndpoint struct {
161                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
162                 UDP uint16 // for discovery protocol
163                 TCP uint16 // for RLPx protocol
164         }
165 )
166
167 var (
168         versionPrefix     = []byte("bytom discovery")
169         versionPrefixSize = len(versionPrefix)
170         nodeIDSize        = 32
171         sigSize           = 520 / 8
172         headSize          = versionPrefixSize + nodeIDSize + sigSize // space of packet frame data
173 )
174
175 // Neighbors replies are sent across multiple packets to
176 // stay below the 1280 byte limit. We compute the maximum number
177 // of entries by stuffing a packet until it grows too large.
178 var maxNeighbors = func() int {
179         p := neighbors{Expiration: ^uint64(0)}
180         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
181         for n := 0; ; n++ {
182                 p.Nodes = append(p.Nodes, maxSizeNode)
183                 var size int
184                 var err error
185                 b := new(bytes.Buffer)
186                 wire.WriteJSON(p, b, &size, &err)
187                 if err != nil {
188                         // If this ever happens, it will be caught by the unit tests.
189                         panic("cannot encode: " + err.Error())
190                 }
191                 if headSize+size+1 >= 1280 {
192                         return n
193                 }
194         }
195 }()
196
197 var maxTopicNodes = func() int {
198         p := topicNodes{}
199         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
200         for n := 0; ; n++ {
201                 p.Nodes = append(p.Nodes, maxSizeNode)
202                 var size int
203                 var err error
204                 b := new(bytes.Buffer)
205                 wire.WriteJSON(p, b, &size, &err)
206                 if err != nil {
207                         // If this ever happens, it will be caught by the unit tests.
208                         panic("cannot encode: " + err.Error())
209                 }
210                 if headSize+size+1 >= 1280 {
211                         return n
212                 }
213         }
214 }()
215
216 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
217         ip := addr.IP.To4()
218         if ip == nil {
219                 ip = addr.IP.To16()
220         }
221         return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
222 }
223
224 func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
225         return e1.UDP == e2.UDP && e1.TCP == e2.TCP && e1.IP.Equal(e2.IP)
226 }
227
228 func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
229         if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
230                 return nil, err
231         }
232         n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
233         err := n.validateComplete()
234         return n, err
235 }
236
237 func nodeToRPC(n *Node) rpcNode {
238         return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
239 }
240
241 type ingressPacket struct {
242         remoteID   NodeID
243         remoteAddr *net.UDPAddr
244         ev         nodeEvent
245         hash       []byte
246         data       interface{} // one of the RPC structs
247         rawData    []byte
248 }
249
250 type conn interface {
251         ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
252         WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
253         Close() error
254         LocalAddr() net.Addr
255 }
256
257 // udp implements the RPC protocol.
258 type udp struct {
259         conn        conn
260         priv        *crypto.PrivKeyEd25519
261         ourEndpoint rpcEndpoint
262         //nat         nat.Interface
263         net *Network
264 }
265
266 // ListenUDP returns a new table that listens for UDP packets on laddr.
267 func ListenUDP(priv *crypto.PrivKeyEd25519, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
268         transport, err := listenUDP(priv, conn, realaddr)
269         if err != nil {
270                 return nil, err
271         }
272         net, err := newNetwork(transport, priv.PubKey().Unwrap().(crypto.PubKeyEd25519), nodeDBPath, netrestrict)
273         if err != nil {
274                 return nil, err
275         }
276         log.Info("UDP listener up v5", "net", net.tab.self)
277         transport.net = net
278         go transport.readLoop()
279         return net, nil
280 }
281
282 func listenUDP(priv *crypto.PrivKeyEd25519, conn conn, realaddr *net.UDPAddr) (*udp, error) {
283         return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
284 }
285
286 func (t *udp) localAddr() *net.UDPAddr {
287         return t.conn.LocalAddr().(*net.UDPAddr)
288 }
289
290 func (t *udp) Close() {
291         t.conn.Close()
292 }
293
294 func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
295         hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
296         return hash
297 }
298
299 func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
300         hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
301                 Version:    Version,
302                 From:       t.ourEndpoint,
303                 To:         makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
304                 Expiration: uint64(time.Now().Add(expiration).Unix()),
305                 Topics:     topics,
306         })
307         return hash
308 }
309
310 func (t *udp) sendFindnode(remote *Node, target NodeID) {
311         t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
312                 Target:     target,
313                 Expiration: uint64(time.Now().Add(expiration).Unix()),
314         })
315 }
316
317 func (t *udp) sendNeighbours(remote *Node, results []*Node) {
318         // Send neighbors in chunks with at most maxNeighbors per packet
319         // to stay below the 1280 byte limit.
320         p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
321         for i, result := range results {
322                 p.Nodes = append(p.Nodes, nodeToRPC(result))
323                 if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
324                         t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
325                         p.Nodes = p.Nodes[:0]
326                 }
327         }
328 }
329
330 func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
331         t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
332                 Target:     common.Hash(target),
333                 Expiration: uint64(time.Now().Add(expiration).Unix()),
334         })
335 }
336
337 func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
338         t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
339                 Topics: topics,
340                 Idx:    uint(idx),
341                 Pong:   pong,
342         })
343 }
344
345 func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
346         p := topicNodes{Echo: queryHash}
347         var sent bool
348         for _, result := range nodes {
349                 if result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
350                         p.Nodes = append(p.Nodes, nodeToRPC(result))
351                 }
352                 if len(p.Nodes) == maxTopicNodes {
353                         t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
354                         p.Nodes = p.Nodes[:0]
355                         sent = true
356                 }
357         }
358         if !sent || len(p.Nodes) > 0 {
359                 t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
360         }
361 }
362
363 func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
364         //fmt.Println("sendPacket", nodeEvent(ptype), toaddr.String(), toid.String())
365         packet, hash, err := encodePacket(t.priv, ptype, req)
366         if err != nil {
367                 //fmt.Println(err)
368                 return hash, err
369         }
370         log.Debug(fmt.Sprintf(">>> %v to %x@%v", nodeEvent(ptype), toid[:8], toaddr))
371         if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
372                 log.Info(fmt.Sprint("UDP send failed:", err))
373         }
374         return hash, err
375 }
376
377 // zeroed padding space for encodePacket.
378 var headSpace = make([]byte, headSize)
379
380 func encodePacket(priv *crypto.PrivKeyEd25519, ptype byte, req interface{}) (p, hash []byte, err error) {
381         b := new(bytes.Buffer)
382         b.Write(headSpace)
383         b.WriteByte(ptype)
384         var size int
385         wire.WriteJSON(req, b, &size, &err)
386         if err != nil {
387                 log.Error(fmt.Sprint("error encoding packet:", err))
388                 return nil, nil, err
389         }
390         packet := b.Bytes()
391         nodeID := priv.PubKey().Unwrap().(crypto.PubKeyEd25519)
392         sig := priv.Sign(common.BytesToHash(packet[headSize:]).Bytes())
393         copy(packet, versionPrefix)
394         copy(packet[versionPrefixSize:], nodeID[:])
395         copy(packet[versionPrefixSize+nodeIDSize:], sig.Bytes())
396
397         hash = common.BytesToHash(packet[versionPrefixSize:]).Bytes()
398         return packet, hash, nil
399 }
400
401 // readLoop runs in its own goroutine. it injects ingress UDP packets
402 // into the network loop.
403 func (t *udp) readLoop() {
404         defer t.conn.Close()
405         // Discovery packets are defined to be no larger than 1280 bytes.
406         // Packets larger than this size will be cut at the end and treated
407         // as invalid because their hash won't match.
408         buf := make([]byte, 1280)
409         for {
410                 nbytes, from, err := t.conn.ReadFromUDP(buf)
411                 if netutil.IsTemporaryError(err) {
412                         // Ignore temporary read errors.
413                         log.Debug(fmt.Sprintf("Temporary read error: %v", err))
414                         continue
415                 } else if err != nil {
416                         // Shut down the loop for permament errors.
417                         log.Debug(fmt.Sprintf("Read error: %v", err))
418                         return
419                 }
420                 t.handlePacket(from, buf[:nbytes])
421         }
422 }
423
424 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
425         pkt := ingressPacket{remoteAddr: from}
426         if err := decodePacket(buf, &pkt); err != nil {
427                 log.Debug(fmt.Sprintf("Bad packet from %v: %v", from, err))
428                 //fmt.Println("bad packet", err)
429                 return err
430         }
431         t.net.reqReadPacket(pkt)
432         return nil
433 }
434
435 func decodePacket(buffer []byte, pkt *ingressPacket) error {
436         if len(buffer) < headSize+1 {
437                 return errPacketTooSmall
438         }
439         buf := make([]byte, len(buffer))
440         copy(buf, buffer)
441         prefix, fromID, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:versionPrefixSize+nodeIDSize], buf[headSize:]
442         if !bytes.Equal(prefix, versionPrefix) {
443                 return errBadPrefix
444         }
445         pkt.rawData = buf
446         pkt.hash = common.BytesToHash(buf[versionPrefixSize:]).Bytes()
447         pkt.remoteID = ByteID(fromID)
448         switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
449         case pingPacket:
450                 pkt.data = new(ping)
451         case pongPacket:
452                 pkt.data = new(pong)
453         case findnodePacket:
454                 pkt.data = new(findnode)
455         case neighborsPacket:
456                 pkt.data = new(neighbors)
457         case findnodeHashPacket:
458                 pkt.data = new(findnodeHash)
459         case topicRegisterPacket:
460                 pkt.data = new(topicRegister)
461         case topicQueryPacket:
462                 pkt.data = new(topicQuery)
463         case topicNodesPacket:
464                 pkt.data = new(topicNodes)
465         default:
466                 return fmt.Errorf("unknown packet type: %d", sigdata[0])
467         }
468         var err error
469         wire.ReadJSON(pkt.data, sigdata[1:], &err)
470         if err != nil {
471                 log.Error("wire readjson err:", err)
472         }
473
474         return err
475 }