OSDN Git Service

Feat(ed25519): replace with crypto/ed25519 (#1907)
[bytom/bytom.git] / p2p / discover / dht / udp.go
1 package dht
2
3 import (
4         "bytes"
5         "crypto/ecdsa"
6         "crypto/ed25519"
7         "encoding/hex"
8         "errors"
9         "fmt"
10         "net"
11         "path"
12         "strconv"
13         "time"
14
15         log "github.com/sirupsen/logrus"
16         "github.com/tendermint/go-wire"
17
18         "github.com/bytom/bytom/common"
19         cfg "github.com/bytom/bytom/config"
20         "github.com/bytom/bytom/crypto"
21         "github.com/bytom/bytom/p2p/netutil"
22         "github.com/bytom/bytom/version"
23 )
24
25 const (
26         Version   = 4
27         logModule = "discover"
28 )
29
30 // Errors
31 var (
32         errPacketTooSmall   = errors.New("too small")
33         errBadPrefix        = errors.New("bad prefix")
34         errExpired          = errors.New("expired")
35         errUnsolicitedReply = errors.New("unsolicited reply")
36         errUnknownNode      = errors.New("unknown node")
37         errTimeout          = errors.New("RPC timeout")
38         errClockWarp        = errors.New("reply deadline too far in the future")
39         errClosed           = errors.New("socket closed")
40 )
41
42 // Timeouts
43 const (
44         respTimeout = 1 * time.Second
45         queryDelay  = 1000 * time.Millisecond
46         expiration  = 20 * time.Second
47
48         ntpFailureThreshold = 32               // Continuous timeouts after which to check NTP
49         ntpWarningCooldown  = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
50         driftThreshold      = 10 * time.Second // Allowed clock drift before warning user
51 )
52
53 // ReadPacket is sent to the unhandled channel when it could not be processed
54 type ReadPacket struct {
55         Data []byte
56         Addr *net.UDPAddr
57 }
58
59 // Config holds Table-related settings.
60 type Config struct {
61         // These settings are required and configure the UDP listener:
62         PrivateKey *ecdsa.PrivateKey
63
64         // These settings are optional:
65         AnnounceAddr *net.UDPAddr // local address announced in the DHT
66         NodeDBPath   string       // if set, the node database is stored at this filesystem location
67         //NetRestrict  *netutil.Netlist  // network whitelist
68         Bootnodes []*Node           // list of bootstrap nodes
69         Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
70 }
71
72 // RPC request structures
73 type (
74         ping struct {
75                 Version    uint
76                 From, To   rpcEndpoint
77                 Expiration uint64
78
79                 // v5
80                 Topics []Topic
81
82                 // Ignore additional fields (for forward compatibility).
83                 Rest []byte
84         }
85
86         // pong is the reply to ping.
87         pong struct {
88                 // This field should mirror the UDP envelope address
89                 // of the ping packet, which provides a way to discover the
90                 // the external address (after NAT).
91                 To rpcEndpoint
92
93                 ReplyTok   []byte // This contains the hash of the ping packet.
94                 Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
95
96                 // v5
97                 TopicHash    common.Hash
98                 TicketSerial uint32
99                 WaitPeriods  []uint32
100
101                 // Ignore additional fields (for forward compatibility).
102                 Rest []byte
103         }
104
105         // findnode is a query for nodes close to the given target.
106         findnode struct {
107                 Target     NodeID // doesn't need to be an actual public key
108                 Expiration uint64
109                 // Ignore additional fields (for forward compatibility).
110                 Rest []byte
111         }
112
113         // findnode is a query for nodes close to the given target.
114         findnodeHash struct {
115                 Target     common.Hash
116                 Expiration uint64
117                 // Ignore additional fields (for forward compatibility).
118                 Rest []byte
119         }
120
121         // reply to findnode
122         neighbors struct {
123                 Nodes      []rpcNode
124                 Expiration uint64
125                 // Ignore additional fields (for forward compatibility).
126                 Rest []byte
127         }
128
129         topicRegister struct {
130                 Topics []Topic
131                 Idx    uint
132                 Pong   []byte
133         }
134
135         topicQuery struct {
136                 Topic      Topic
137                 Expiration uint64
138         }
139
140         // reply to topicQuery
141         topicNodes struct {
142                 Echo  common.Hash
143                 Nodes []rpcNode
144         }
145
146         rpcNode struct {
147                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
148                 UDP uint16 // for discovery protocol
149                 TCP uint16 // for RLPx protocol
150                 ID  NodeID
151         }
152
153         rpcEndpoint struct {
154                 IP  net.IP // len 4 for IPv4 or 16 for IPv6
155                 UDP uint16 // for discovery protocol
156                 TCP uint16 // for RLPx protocol
157         }
158 )
159
160 var (
161         versionPrefix     = []byte("bytom discovery")
162         versionPrefixSize = len(versionPrefix)
163         nodeIDSize        = 32
164         sigSize           = 520 / 8
165         headSize          = versionPrefixSize + nodeIDSize + sigSize // space of packet frame data
166 )
167
168 // Neighbors replies are sent across multiple packets to
169 // stay below the 1280 byte limit. We compute the maximum number
170 // of entries by stuffing a packet until it grows too large.
171 var maxNeighbors = func() int {
172         p := neighbors{Expiration: ^uint64(0)}
173         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
174         for n := 0; ; n++ {
175                 p.Nodes = append(p.Nodes, maxSizeNode)
176                 var size int
177                 var err error
178                 b := new(bytes.Buffer)
179                 wire.WriteJSON(p, b, &size, &err)
180                 if err != nil {
181                         // If this ever happens, it will be caught by the unit tests.
182                         panic("cannot encode: " + err.Error())
183                 }
184                 if headSize+size+1 >= 1280 {
185                         return n
186                 }
187         }
188 }()
189
190 var maxTopicNodes = func() int {
191         p := topicNodes{}
192         maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
193         for n := 0; ; n++ {
194                 p.Nodes = append(p.Nodes, maxSizeNode)
195                 var size int
196                 var err error
197                 b := new(bytes.Buffer)
198                 wire.WriteJSON(p, b, &size, &err)
199                 if err != nil {
200                         // If this ever happens, it will be caught by the unit tests.
201                         panic("cannot encode: " + err.Error())
202                 }
203                 if headSize+size+1 >= 1280 {
204                         return n
205                 }
206         }
207 }()
208
209 func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
210         ip := addr.IP.To4()
211         if ip == nil {
212                 ip = addr.IP.To16()
213         }
214         return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
215 }
216
217 func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
218         return e1.UDP == e2.UDP && e1.TCP == e2.TCP && e1.IP.Equal(e2.IP)
219 }
220
221 func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
222         if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
223                 return nil, err
224         }
225         n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
226         err := n.validateComplete()
227         return n, err
228 }
229
230 func nodeToRPC(n *Node) rpcNode {
231         return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
232 }
233
234 type ingressPacket struct {
235         remoteID   NodeID
236         remoteAddr *net.UDPAddr
237         ev         nodeEvent
238         hash       []byte
239         data       interface{} // one of the RPC structs
240         rawData    []byte
241 }
242
243 type conn interface {
244         ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
245         WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
246         Close() error
247         LocalAddr() net.Addr
248 }
249
250 type netWork interface {
251         reqReadPacket(pkt ingressPacket)
252         selfIP() net.IP
253 }
254
255 // udp implements the RPC protocol.
256 type udp struct {
257         conn        conn
258         priv        ed25519.PrivateKey
259         ourEndpoint rpcEndpoint
260         //nat         nat.Interface
261         net netWork
262 }
263
264 func NewDiscover(config *cfg.Config, priv ed25519.PrivateKey, port uint16) (*Network, error) {
265         addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
266         if err != nil {
267                 return nil, err
268         }
269
270         conn, err := net.ListenUDP("udp", addr)
271         if err != nil {
272                 return nil, err
273         }
274
275         realaddr := conn.LocalAddr().(*net.UDPAddr)
276         ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover"), nil)
277         if err != nil {
278                 return nil, err
279         }
280         seeds, err := QueryDNSSeeds(net.LookupHost)
281         if err != nil {
282                 log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on query dns seeds")
283         }
284
285         codedSeeds := netutil.CheckAndSplitAddresses(config.P2P.Seeds)
286         seeds = append(seeds, codedSeeds...)
287         if len(seeds) == 0 {
288                 return ntab, nil
289         }
290
291         var nodes []*Node
292         for _, seed := range seeds {
293                 version.Status.AddSeed(seed)
294                 url := "enode://" + hex.EncodeToString(crypto.Sha256([]byte(seed))) + "@" + seed
295                 nodes = append(nodes, MustParseNode(url))
296         }
297
298         if err = ntab.SetFallbackNodes(nodes); err != nil {
299                 return nil, err
300         }
301         return ntab, nil
302 }
303
304 // ListenUDP returns a new table that listens for UDP packets on laddr.
305 func ListenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
306         transport, err := listenUDP(priv, conn, realaddr)
307         if err != nil {
308                 return nil, err
309         }
310
311         net, err := newNetwork(transport, priv.Public(), nodeDBPath, netrestrict)
312         if err != nil {
313                 return nil, err
314         }
315         log.WithFields(log.Fields{"module": logModule, "net": net.tab.self}).Info("UDP listener up v5")
316         transport.net = net
317         go transport.readLoop()
318         return net, nil
319 }
320
321 func listenUDP(priv ed25519.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
322         return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
323 }
324
325 func (t *udp) localAddr() *net.UDPAddr {
326         return t.conn.LocalAddr().(*net.UDPAddr)
327 }
328
329 func (t *udp) Close() {
330         t.conn.Close()
331 }
332
333 func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
334         hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
335         return hash
336 }
337
338 func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
339         hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
340                 Version:    Version,
341                 From:       t.ourEndpoint,
342                 To:         makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
343                 Expiration: uint64(time.Now().Add(expiration).Unix()),
344                 Topics:     topics,
345         })
346         return hash
347 }
348
349 func (t *udp) sendFindnode(remote *Node, target NodeID) {
350         t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
351                 Target:     target,
352                 Expiration: uint64(time.Now().Add(expiration).Unix()),
353         })
354 }
355
356 func (t *udp) sendNeighbours(remote *Node, results []*Node) {
357         // Send neighbors in chunks with at most maxNeighbors per packet
358         // to stay below the 1280 byte limit.
359         p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
360         for i, result := range results {
361                 p.Nodes = append(p.Nodes, nodeToRPC(result))
362                 if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
363                         t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
364                         p.Nodes = p.Nodes[:0]
365                 }
366         }
367 }
368
369 func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
370         t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
371                 Target:     common.Hash(target),
372                 Expiration: uint64(time.Now().Add(expiration).Unix()),
373         })
374 }
375
376 func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
377         t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
378                 Topics: topics,
379                 Idx:    uint(idx),
380                 Pong:   pong,
381         })
382 }
383
384 func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
385         p := topicNodes{Echo: queryHash}
386         var sent bool
387         for _, result := range nodes {
388                 if result.IP.Equal(t.net.selfIP()) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
389                         p.Nodes = append(p.Nodes, nodeToRPC(result))
390                 }
391                 if len(p.Nodes) == maxTopicNodes {
392                         t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
393                         p.Nodes = p.Nodes[:0]
394                         sent = true
395                 }
396         }
397         if !sent || len(p.Nodes) > 0 {
398                 t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
399         }
400 }
401
402 func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
403         packet, hash, err := encodePacket(t.priv, ptype, req)
404         if err != nil {
405                 return hash, err
406         }
407         log.WithFields(log.Fields{"module": logModule, "event": nodeEvent(ptype), "to id": hex.EncodeToString(toid[:8]), "to addr": toaddr}).Debug("send packet")
408         if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
409                 log.WithFields(log.Fields{"module": logModule, "error": err}).Info(fmt.Sprint("UDP send failed"))
410         }
411         return hash, err
412 }
413
414 // zeroed padding space for encodePacket.
415 var headSpace = make([]byte, headSize)
416
417 func encodePacket(priv ed25519.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
418         b := new(bytes.Buffer)
419         b.Write(headSpace)
420         b.WriteByte(ptype)
421         var size int
422         wire.WriteJSON(req, b, &size, &err)
423         if err != nil {
424                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("error encoding packet")
425                 return nil, nil, err
426         }
427         packet := b.Bytes()
428         nodeID := priv.Public()
429         sig := ed25519.Sign(priv, common.BytesToHash(packet[headSize:]).Bytes())
430         copy(packet, versionPrefix)
431         copy(packet[versionPrefixSize:], nodeID.([]byte)[:])
432         copy(packet[versionPrefixSize+nodeIDSize:], sig)
433
434         hash = common.BytesToHash(packet[versionPrefixSize:]).Bytes()
435         return packet, hash, nil
436 }
437
438 // readLoop runs in its own goroutine. it injects ingress UDP packets
439 // into the network loop.
440 func (t *udp) readLoop() {
441         defer t.conn.Close()
442         // Discovery packets are defined to be no larger than 1280 bytes.
443         // Packets larger than this size will be cut at the end and treated
444         // as invalid because their hash won't match.
445         buf := make([]byte, 1280)
446         for {
447                 nbytes, from, err := t.conn.ReadFromUDP(buf)
448                 if netutil.IsTemporaryError(err) {
449                         // Ignore temporary read errors.
450                         log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Temporary read error")
451                         continue
452                 } else if err != nil {
453                         // Shut down the loop for permament errors.
454                         log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Read error")
455                         return
456                 }
457                 t.handlePacket(from, buf[:nbytes])
458         }
459 }
460
461 func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
462         pkt := ingressPacket{remoteAddr: from}
463         if err := decodePacket(buf, &pkt); err != nil {
464                 log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("Bad packet")
465                 return err
466         }
467         t.net.reqReadPacket(pkt)
468         return nil
469 }
470
471 func decodePacket(buffer []byte, pkt *ingressPacket) error {
472         if len(buffer) < headSize+1 {
473                 return errPacketTooSmall
474         }
475         buf := make([]byte, len(buffer))
476         copy(buf, buffer)
477         prefix, fromID, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:versionPrefixSize+nodeIDSize], buf[headSize:]
478         if !bytes.Equal(prefix, versionPrefix) {
479                 return errBadPrefix
480         }
481         pkt.rawData = buf
482         pkt.hash = common.BytesToHash(buf[versionPrefixSize:]).Bytes()
483         pkt.remoteID = ByteID(fromID)
484         switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
485         case pingPacket:
486                 pkt.data = new(ping)
487         case pongPacket:
488                 pkt.data = new(pong)
489         case findnodePacket:
490                 pkt.data = new(findnode)
491         case neighborsPacket:
492                 pkt.data = new(neighbors)
493         case findnodeHashPacket:
494                 pkt.data = new(findnodeHash)
495         case topicRegisterPacket:
496                 pkt.data = new(topicRegister)
497         case topicQueryPacket:
498                 pkt.data = new(topicQuery)
499         case topicNodesPacket:
500                 pkt.data = new(topicNodes)
501         default:
502                 return fmt.Errorf("unknown packet type: %d", sigdata[0])
503         }
504         var err error
505         wire.ReadJSON(pkt.data, sigdata[1:], &err)
506         if err != nil {
507                 log.WithFields(log.Fields{"module": logModule, "error": err}).Error("wire readjson err")
508         }
509
510         return err
511 }