OSDN Git Service

c8d73b70eb427c2639f1171dc71253fda940b92b
[bytom/vapor.git] / p2p / discover / dht / udp.go
1 package dht
2
3 import (
4         "bytes"
5         "crypto/ecdsa"
6         "encoding/hex"
7         "errors"
8         "fmt"
9         "net"
10         "path"
11         "strconv"
12         "time"
13
14         log "github.com/sirupsen/logrus"
15         "github.com/tendermint/go-wire"
16
17         "github.com/vapor/common"
18         cfg "github.com/vapor/config"
19         "github.com/vapor/crypto"
20         "github.com/vapor/p2p/netutil"
21         "github.com/vapor/p2p/signlib"
22         "github.com/vapor/version"
23 )
24
25 const (
26         //Version dht discover protocol version
27         Version   = 5
28         logModule = "discover"
29 )
30
31 // Errors
32 var (
33         errPacketTooSmall = errors.New("too small")
34         errPrefixMismatch = errors.New("prefix mismatch")
35         errNetIDMismatch  = errors.New("network id mismatch")
36         errPacketType     = errors.New("unknown packet type")
37 )
38
39 // Timeouts
40 const (
41         respTimeout = 1 * time.Second
42         expiration  = 20 * time.Second
43 )
44
45 // ReadPacket is sent to the unhandled channel when it could not be processed
46 type ReadPacket struct {
47         Data []byte
48         Addr *net.UDPAddr
49 }
50
51 // Config holds Table-related settings.
52 type Config struct {
53         // These settings are required and configure the UDP listener:
54         PrivateKey *ecdsa.PrivateKey
55
56         // These settings are optional:
57         AnnounceAddr *net.UDPAddr // local address announced in the DHT
58         NodeDBPath   string       // if set, the node database is stored at this filesystem location
59         //NetRestrict  *netutil.Netlist  // network whitelist
60         Bootnodes []*Node           // list of bootstrap nodes
61         Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
62 }
63
64 // RPC request structures
65 type (
66         ping struct {
67                 Version    uint
68                 From, To   rpcEndpoint
69                 Expiration uint64
70
71                 // v5
72                 Topics []Topic
73
74                 // Ignore additional fields (for forward compatibility).
75                 Rest []byte
76         }
77
78         // pong is the reply to ping.
79         pong struct {
80                 // This field should mirror the UDP envelope address
81                 // of the ping packet, which provides a way to discover the
82                 // the external address (after NAT).
83                 To rpcEndpoint
84
85                 ReplyTok   []byte // This contains the hash of the ping packet.
86                 Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
87
88                 // v5
89                 TopicHash    common.Hash
90                 TicketSerial uint32
91                 WaitPeriods  []uint32
92
93                 // Ignore additional fields (for forward compatibility).
94                 Rest []byte
95         }
96
97         // findnode is a query for nodes close to the given target.
98         findnode struct {
99                 Target     NodeID // doesn't need to be an actual public key
100                 Expiration uint64
101                 // Ignore additional fields (for forward compatibility).
102                 Rest []byte
103         }
104
105         // findnode is a query for nodes close to the given target.
106         findnodeHash struct {
107                 Target     common.Hash
108                 Expiration uint64
109                 // Ignore additional fields (for forward compatibility).
110                 Rest []byte
111         }
112
113         // reply to findnode
114         neighbors struct {
115                 Nodes      []rpcNode
116                 Expiration uint64
117                 // Ignore additional fields (for forward compatibility).
118                 Rest []byte
119         }
120
121         topicRegister struct {
122                 Topics []Topic
123                 Idx    uint
124                 Pong   []byte
125         }
126
127         topicQuery struct {
128                 Topic      Topic
129                 Expiration uint64
130         }
131
132         // reply to topicQuery
133         topicNodes struct {
134                 Echo  common.Hash
135                 Nodes []rpcNode
136         }
137
138         rpcNode struct {
139                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
140                 UDP uint16 // for discovery protocol
141                 TCP uint16 // for RLPx protocol
142                 ID  NodeID
143         }
144
145         rpcEndpoint struct {
146                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
147                 UDP uint16 // for discovery protocol
148                 TCP uint16 // for RLPx protocol
149         }
150 )
151
152 var (
153         netIDSize  = 8
154         nodeIDSize = 32
155         sigSize    = 520 / 8
156         headSize   = netIDSize + nodeIDSize + sigSize // space of packet frame data
157 )
158
159 // Neighbors replies are sent across multiple packets to
160 // stay below the 1280 byte limit. We compute the maximum number
161 // of entries by stuffing a packet until it grows too large.
162 var maxNeighbors = func() int {
163         p := neighbors{Expiration: ^uint64(0)}
164         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
165         for n := 0; ; n++ {
166                 p.Nodes = append(p.Nodes, maxSizeNode)
167                 var size int
168                 var err error
169                 b := new(bytes.Buffer)
170                 wire.WriteJSON(p, b, &size, &err)
171                 if err != nil {
172                         // If this ever happens, it will be caught by the unit tests.
173                         panic("cannot encode: " + err.Error())
174                 }
175                 if headSize+size+1 >= 1280 {
176                         return n
177                 }
178         }
179 }()
180
181 var maxTopicNodes = func() int {
182         p := topicNodes{}
183         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
184         for n := 0; ; n++ {
185                 p.Nodes = append(p.Nodes, maxSizeNode)
186                 var size int
187                 var err error
188                 b := new(bytes.Buffer)
189                 wire.WriteJSON(p, b, &size, &err)
190                 if err != nil {
191                         // If this ever happens, it will be caught by the unit tests.
192                         panic("cannot encode: " + err.Error())
193                 }
194                 if headSize+size+1 >= 1280 {
195                         return n
196                 }
197         }
198 }()
199
200 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
201         ip := addr.IP.To4()
202         if ip == nil {
203                 ip = addr.IP.To16()
204         }
205         return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
206 }
207
208 func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
209         return e1.UDP == e2.UDP && e1.TCP == e2.TCP && e1.IP.Equal(e2.IP)
210 }
211
212 func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
213         if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
214                 return nil, err
215         }
216         n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
217         err := n.validateComplete()
218         return n, err
219 }
220
221 func nodeToRPC(n *Node) rpcNode {
222         return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
223 }
224
225 type ingressPacket struct {
226         remoteID   NodeID
227         remoteAddr *net.UDPAddr
228         ev         nodeEvent
229         hash       []byte
230         data       interface{} // one of the RPC structs
231         rawData    []byte
232 }
233
234 type conn interface {
235         ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
236         WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
237         Close() error
238         LocalAddr() net.Addr
239 }
240
241 type netWork interface {
242         reqReadPacket(pkt ingressPacket)
243         selfIP() net.IP
244 }
245
246 // udp implements the RPC protocol.
247 type udp struct {
248         conn conn
249         priv signlib.PrivKey
250         //netID used to isolate subnets
251         netID       uint64
252         ourEndpoint rpcEndpoint
253         net         netWork
254 }
255
256 //NewDiscover create new dht discover
257 func NewDiscover(config *cfg.Config, privKey signlib.PrivKey, port uint16, netID uint64) (*Network, error) {
258         addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
259         if err != nil {
260                 return nil, err
261         }
262
263         conn, err := net.ListenUDP("udp", addr)
264         if err != nil {
265                 return nil, err
266         }
267
268         realaddr := conn.LocalAddr().(*net.UDPAddr)
269         ntab, err := ListenUDP(privKey, conn, realaddr, path.Join(config.DBDir(), "discover"), nil, netID)
270         if err != nil {
271                 return nil, err
272         }
273         seeds, err := QueryDNSSeeds(net.LookupHost)
274         if err != nil {
275                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on query dns seeds")
276         }
277
278         codedSeeds := netutil.CheckAndSplitAddresses(config.P2P.Seeds)
279         seeds = append(seeds, codedSeeds...)
280         if len(seeds) == 0 {
281                 return ntab, nil
282         }
283
284         var nodes []*Node
285         for _, seed := range seeds {
286                 version.Status.AddSeed(seed)
287                 url := "enode://" + hex.EncodeToString(crypto.Sha256([]byte(seed))) + "@" + seed
288                 nodes = append(nodes, MustParseNode(url))
289         }
290
291         if err = ntab.SetFallbackNodes(nodes); err != nil {
292                 return nil, err
293         }
294         return ntab, nil
295 }
296
297 // ListenUDP returns a new table that listens for UDP packets on laddr.
298 func ListenUDP(privKey signlib.PrivKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist, netID uint64) (*Network, error) {
299         transport, err := listenUDP(privKey, conn, realaddr, netID)
300         if err != nil {
301                 return nil, err
302         }
303
304         net, err := newNetwork(transport, privKey.XPub(), nodeDBPath, netrestrict)
305         if err != nil {
306                 return nil, err
307         }
308         log.WithFields(log.Fields{"module": logModule, "net": net.tab.self}).Info("UDP listener up v5")
309         transport.net = net
310         go transport.readLoop()
311         return net, nil
312 }
313
314 func listenUDP(priv signlib.PrivKey, conn conn, realaddr *net.UDPAddr, netID uint64) (*udp, error) {
315         return &udp{conn: conn, priv: priv, netID: netID, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
316 }
317
318 func (t *udp) localAddr() *net.UDPAddr {
319         return t.conn.LocalAddr().(*net.UDPAddr)
320 }
321
322 func (t *udp) Close() {
323         t.conn.Close()
324 }
325
326 func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
327         hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
328         return hash
329 }
330
331 func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
332         hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
333                 Version:    Version,
334                 From:       t.ourEndpoint,
335                 To:         makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
336                 Expiration: uint64(time.Now().Add(expiration).Unix()),
337                 Topics:     topics,
338         })
339         return hash
340 }
341
342 func (t *udp) sendFindnode(remote *Node, target NodeID) {
343         t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
344                 Target:     target,
345                 Expiration: uint64(time.Now().Add(expiration).Unix()),
346         })
347 }
348
349 func (t *udp) sendNeighbours(remote *Node, results []*Node) {
350         // Send neighbors in chunks with at most maxNeighbors per packet
351         // to stay below the 1280 byte limit.
352         p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
353         for i, result := range results {
354                 p.Nodes = append(p.Nodes, nodeToRPC(result))
355                 if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
356                         t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
357                         p.Nodes = p.Nodes[:0]
358                 }
359         }
360 }
361
362 func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
363         t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
364                 Target:     target,
365                 Expiration: uint64(time.Now().Add(expiration).Unix()),
366         })
367 }
368
369 func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
370         t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
371                 Topics: topics,
372                 Idx:    uint(idx),
373                 Pong:   pong,
374         })
375 }
376
377 func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
378         p := topicNodes{Echo: queryHash}
379         var sent bool
380         for _, result := range nodes {
381                 if result.IP.Equal(t.net.selfIP()) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
382                         p.Nodes = append(p.Nodes, nodeToRPC(result))
383                 }
384                 if len(p.Nodes) == maxTopicNodes {
385                         t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
386                         p.Nodes = p.Nodes[:0]
387                         sent = true
388                 }
389         }
390         if !sent || len(p.Nodes) > 0 {
391                 t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
392         }
393 }
394
395 func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
396         packet, hash, err := encodePacket(t.priv, ptype, req, t.netID)
397         if err != nil {
398                 return hash, err
399         }
400         log.WithFields(log.Fields{"module": logModule, "event": nodeEvent(ptype), "to id": hex.EncodeToString(toid[:8]), "to addr": toaddr}).Debug("send packet")
401         if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
402                 log.WithFields(log.Fields{"module": logModule, "error": err}).Info(fmt.Sprint("UDP send failed"))
403         }
404         return hash, err
405 }
406
407 // zeroed padding space for encodePacket.
408 var headSpace = make([]byte, headSize)
409
410 func encodePacket(privKey signlib.PrivKey, ptype byte, req interface{}, netID uint64) (p, hash []byte, err error) {
411         b := new(bytes.Buffer)
412         b.Write(headSpace)
413         b.WriteByte(ptype)
414         var size int
415         wire.WriteJSON(req, b, &size, &err)
416         if err != nil {
417                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("error encoding packet")
418                 return nil, nil, err
419         }
420         packet := b.Bytes()
421         nodeID := privKey.XPub()
422         sig := privKey.Sign(common.BytesToHash(packet[headSize:]).Bytes())
423         id := []byte(strconv.FormatUint(netID, 16))
424         copy(packet[:], id[:])
425         copy(packet[netIDSize:], nodeID[:nodeIDSize])
426         copy(packet[netIDSize+nodeIDSize:], sig)
427
428         hash = common.BytesToHash(packet[:]).Bytes()
429         return packet, hash, nil
430 }
431
432 // readLoop runs in its own goroutine. it injects ingress UDP packets
433 // into the network loop.
434 func (t *udp) readLoop() {
435         defer t.conn.Close()
436         // Discovery packets are defined to be no larger than 1280 bytes.
437         // Packets larger than this size will be cut at the end and treated
438         // as invalid because their hash won't match.
439         buf := make([]byte, 1280)
440         for {
441                 nbytes, from, err := t.conn.ReadFromUDP(buf)
442                 if netutil.IsTemporaryError(err) {
443                         // Ignore temporary read errors.
444                         log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Temporary read error")
445                         continue
446                 } else if err != nil {
447                         // Shut down the loop for permament errors.
448                         log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Read error")
449                         return
450                 }
451                 if err := t.handlePacket(from, buf[:nbytes]); err != nil {
452                         log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("handle packet err")
453                 }
454         }
455 }
456
457 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
458         pkt := ingressPacket{remoteAddr: from}
459         if err := decodePacket(buf, &pkt, t.netID); err != nil {
460                 return err
461         }
462         t.net.reqReadPacket(pkt)
463         return nil
464 }
465
466 func (t *udp) getNetID() uint64 {
467         return t.netID
468 }
469
470 func decodePacket(buffer []byte, pkt *ingressPacket, netID uint64) error {
471         if len(buffer) < headSize+1 {
472                 return errPacketTooSmall
473         }
474         buf := make([]byte, len(buffer))
475         copy(buf, buffer)
476         fromID, sigdata := buf[netIDSize:netIDSize+nodeIDSize], buf[headSize:]
477
478         if !bytes.Equal(buf[:netIDSize], []byte(strconv.FormatUint(netID, 16))[:netIDSize]) {
479                 return errNetIDMismatch
480         }
481
482         pkt.rawData = buf
483         pkt.hash = common.BytesToHash(buf[:]).Bytes()
484         pkt.remoteID = ByteID(fromID)
485         switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
486         case pingPacket:
487                 pkt.data = new(ping)
488         case pongPacket:
489                 pkt.data = new(pong)
490         case findnodePacket:
491                 pkt.data = new(findnode)
492         case neighborsPacket:
493                 pkt.data = new(neighbors)
494         case findnodeHashPacket:
495                 pkt.data = new(findnodeHash)
496         case topicRegisterPacket:
497                 pkt.data = new(topicRegister)
498         case topicQueryPacket:
499                 pkt.data = new(topicQuery)
500         case topicNodesPacket:
501                 pkt.data = new(topicNodes)
502         default:
503                 return errPacketType
504         }
505         var err error
506         wire.ReadJSON(pkt.data, sigdata[1:], &err)
507         if err != nil {
508                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("wire readjson err")
509         }
510
511         return err
512 }