OSDN Git Service

Feat(ed25519): replace with crypto/ed25519 (#1907)
[bytom/bytom.git] / p2p / discover / dht / net.go
1 package dht
2
import (
	"bytes"
	"crypto"
	"crypto/ed25519"
	"encoding/hex"
	"errors"
	"fmt"
	"net"
	"time"

	log "github.com/sirupsen/logrus"
	"github.com/tendermint/go-wire"
	"golang.org/x/crypto/sha3"

	"github.com/bytom/bytom/common"
	"github.com/bytom/bytom/p2p/netutil"
)
19
// Package-level sentinel errors; callers compare against them with
// errors.Is or ==.
var errInvalidEvent = errors.New("invalid in current state")
var errNoQuery = errors.New("no pending query")
var errWrongAddress = errors.New("unknown sender address")

// Tuning parameters for table refresh, seeding and logging.
const (
	// autoRefreshInterval is the period of the main loop's full-refresh ticker.
	autoRefreshInterval = time.Hour
	// bucketRefreshInterval is the delay between bucket-refresh lookups.
	bucketRefreshInterval = time.Minute

	// seedCount and seedMaxAge are passed to the node database when it is
	// queried for refresh seeds.
	seedCount  = 30
	seedMaxAge = 5 * 24 * time.Hour

	// lowPort: neighbour records advertising a UDP port <= lowPort are rejected.
	lowPort = 1024

	// printTestImgLogs enables the verbose *R / *MR / *W stats lines emitted
	// on every stats-dump tick.
	printTestImgLogs = false
)
37
// Network manages the table and all protocol interaction.
//
// The channel fields connect the public methods (Close, Lookup,
// RegisterTopic, SearchTopic, ...) to the single loop goroutine started by
// newNetwork. The "state of the main loop" fields are read and written from
// that goroutine without locking.
type Network struct {
	db          *nodeDB // database of known nodes; nil when created with dbPath "<no database>"
	conn        transport
	netrestrict *netutil.Netlist // optional whitelist checked when interning neighbour records; may be nil

	closed           chan struct{}          // closed when loop is done
	closeReq         chan struct{}          // 'request to close'
	refreshReq       chan []*Node           // lookups ask for refresh on this channel
	refreshResp      chan (<-chan struct{}) // ...and get the channel to block on from this one
	read             chan ingressPacket     // ingress packets arrive here
	timeout          chan timeoutEvent      // state-transition timeouts are delivered here
	queryReq         chan *findnodeQuery    // lookups submit findnode queries on this channel
	tableOpReq       chan func()            // functions to run on the loop goroutine (see reqTableOp)
	tableOpResp      chan struct{}          // signalled after each tableOpReq function has returned
	topicRegisterReq chan topicRegisterReq  // add/remove topic registrations
	topicSearchReq   chan topicSearchReq    // start, re-period or stop topic searches

	// State of the main loop.
	tab           *Table
	topictab      *topicTable
	ticketStore   *ticketStore
	nursery       []*Node                      // fallback seeds from SetFallbackNodes, used by refresh when the DB has none
	nodes         map[NodeID]*Node             // tracks active nodes with state != known
	timeoutTimers map[timeoutEvent]*time.Timer // pending timeouts; an entry is deleted when its event is handled

	// Revalidation queues.
	// Nodes put on these queues will be pinged eventually.
	slowRevalidateQueue []*Node
	fastRevalidateQueue []*Node

	// Buffers for state transition.
	sendBuf []*ingressPacket
}
72
// transport is implemented by the UDP transport.
// It is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
	// sendPing pings remote at remoteAddr, advertising topics. Presumably the
	// returned hash identifies the ping so its pong can be matched — TODO confirm.
	sendPing(remote *Node, remoteAddr *net.UDPAddr, topics []Topic) (hash []byte)
	sendNeighbours(remote *Node, nodes []*Node)
	sendFindnodeHash(remote *Node, target common.Hash)
	// sendTopicRegister is called by the main loop when the next ticket's
	// registration time arrives.
	sendTopicRegister(remote *Node, topics []Topic, topicIdx int, pong []byte)
	sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)

	// send transmits packet p of type ptype to remote. The main loop uses it
	// for topic queries and returns the hash to ticketStore.searchLookupDone.
	send(remote *Node, ptype nodeEvent, p interface{}) (hash []byte)

	// localAddr is used by newNetwork to build the local table entry.
	localAddr() *net.UDPAddr
	Close()
}
88
// findnodeQuery is a findnode request that lookup submits to the main loop
// via reqQueryFindnode.
type findnodeQuery struct {
	remote   *Node
	target   common.Hash
	reply    chan<- []*Node // lookup waits for results on this channel
	nresults int            // counter for received nodes
}

// topicRegisterReq asks the main loop to start (add == true) or stop
// (add == false) registering topic. RegisterTopic builds it with a
// positional literal, so the field order matters.
type topicRegisterReq struct {
	add   bool
	topic Topic
}

// topicSearchReq asks the main loop to start, re-period or stop a search for
// topic. For an existing search, a zero delay removes it and a non-zero
// delay replaces its period. A new search is only started for a non-zero
// delay.
type topicSearchReq struct {
	topic  Topic
	found  chan<- *Node  // handed to ticketStore.addSearchTopic
	lookup chan<- bool   // receives the radius "converged" flag after each search lookup
	delay  time.Duration // search period; zero means stop
}

// topicSearchResult carries the outcome of one topic search lookup back to
// the main loop.
type topicSearchResult struct {
	target lookupInfo
	nodes  []*Node
}

// timeoutEvent identifies a pending state-transition timeout (event ev for
// node). It is the key type of Network.timeoutTimers.
type timeoutEvent struct {
	ev   nodeEvent
	node *Node
}
117
118 func newNetwork(conn transport, ourPubkey crypto.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
119         var ourID NodeID
120         copy(ourID[:], ourPubkey.([]byte)[:nodeIDBits])
121
122         var db *nodeDB
123         if dbPath != "<no database>" {
124                 var err error
125                 if db, err = newNodeDB(dbPath, Version, ourID); err != nil {
126                         return nil, err
127                 }
128         }
129
130         tab := newTable(ourID, conn.localAddr())
131         net := &Network{
132                 db:               db,
133                 conn:             conn,
134                 netrestrict:      netrestrict,
135                 tab:              tab,
136                 topictab:         newTopicTable(db, tab.self),
137                 ticketStore:      newTicketStore(),
138                 refreshReq:       make(chan []*Node),
139                 refreshResp:      make(chan (<-chan struct{})),
140                 closed:           make(chan struct{}),
141                 closeReq:         make(chan struct{}),
142                 read:             make(chan ingressPacket, 100),
143                 timeout:          make(chan timeoutEvent),
144                 timeoutTimers:    make(map[timeoutEvent]*time.Timer),
145                 tableOpReq:       make(chan func()),
146                 tableOpResp:      make(chan struct{}),
147                 queryReq:         make(chan *findnodeQuery),
148                 topicRegisterReq: make(chan topicRegisterReq),
149                 topicSearchReq:   make(chan topicSearchReq),
150                 nodes:            make(map[NodeID]*Node),
151         }
152         go net.loop()
153         return net, nil
154 }
155
156 // Close terminates the network listener and flushes the node database.
157 func (net *Network) Close() {
158         net.conn.Close()
159         select {
160         case <-net.closed:
161         case net.closeReq <- struct{}{}:
162                 <-net.closed
163         }
164 }
165
166 // Self returns the local node.
167 // The returned node should not be modified by the caller.
168 func (net *Network) Self() *Node {
169         return net.tab.self
170 }
171
172 func (net *Network) selfIP() net.IP {
173         return net.tab.self.IP
174 }
175
176 // ReadRandomNodes fills the given slice with random nodes from the
177 // table. It will not write the same node more than once. The nodes in
178 // the slice are copies and can be modified by the caller.
179 func (net *Network) ReadRandomNodes(buf []*Node) (n int) {
180         net.reqTableOp(func() { n = net.tab.readRandomNodes(buf) })
181         return n
182 }
183
184 // SetFallbackNodes sets the initial points of contact. These nodes
185 // are used to connect to the network if the table is empty and there
186 // are no known nodes in the database.
187 func (net *Network) SetFallbackNodes(nodes []*Node) error {
188         nursery := make([]*Node, 0, len(nodes))
189         for _, n := range nodes {
190                 if err := n.validateComplete(); err != nil {
191                         return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
192                 }
193                 // Recompute cpy.sha because the node might not have been
194                 // created by NewNode or ParseNode.
195                 cpy := *n
196                 cpy.sha = common.BytesToHash(n.ID[:])
197                 nursery = append(nursery, &cpy)
198         }
199         net.reqRefresh(nursery)
200         return nil
201 }
202
203 // Resolve searches for a specific node with the given ID.
204 // It returns nil if the node could not be found.
205 func (net *Network) Resolve(targetID NodeID) *Node {
206         result := net.lookup(common.BytesToHash(targetID[:]), true)
207         for _, n := range result {
208                 if n.ID == targetID {
209                         return n
210                 }
211         }
212         return nil
213 }
214
215 // Lookup performs a network search for nodes close
216 // to the given target. It approaches the target by querying
217 // nodes that are closer to it on each iteration.
218 // The given target does not need to be an actual node
219 // identifier.
220 //
221 // The local node may be included in the result.
222 func (net *Network) Lookup(targetID NodeID) []*Node {
223         return net.lookup(common.BytesToHash(targetID[:]), false)
224 }
225
226 func (net *Network) lookup(target common.Hash, stopOnMatch bool) []*Node {
227         var (
228                 asked          = make(map[NodeID]bool)
229                 seen           = make(map[NodeID]bool)
230                 reply          = make(chan []*Node, alpha)
231                 result         = nodesByDistance{target: target}
232                 pendingQueries = 0
233         )
234         // Get initial answers from the local node.
235         result.push(net.tab.self, bucketSize)
236         for {
237                 // Ask the Î± closest nodes that we haven't asked yet.
238                 for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
239                         n := result.entries[i]
240                         if !asked[n.ID] {
241                                 asked[n.ID] = true
242                                 pendingQueries++
243                                 net.reqQueryFindnode(n, target, reply)
244                         }
245                 }
246                 if pendingQueries == 0 {
247                         // We have asked all closest nodes, stop the search.
248                         break
249                 }
250                 // Wait for the next reply.
251                 select {
252                 case nodes := <-reply:
253                         for _, n := range nodes {
254                                 if n != nil && !seen[n.ID] {
255                                         seen[n.ID] = true
256                                         result.push(n, bucketSize)
257                                         if stopOnMatch && n.sha == target {
258                                                 return result.entries
259                                         }
260                                 }
261                         }
262                         pendingQueries--
263                 case <-time.After(respTimeout):
264                         // forget all pending requests, start new ones
265                         pendingQueries = 0
266                         reply = make(chan []*Node, alpha)
267                 }
268         }
269         return result.entries
270 }
271
272 func (net *Network) RegisterTopic(topic Topic, stop <-chan struct{}) {
273         select {
274         case net.topicRegisterReq <- topicRegisterReq{true, topic}:
275         case <-net.closed:
276                 return
277         }
278         select {
279         case <-net.closed:
280         case <-stop:
281                 select {
282                 case net.topicRegisterReq <- topicRegisterReq{false, topic}:
283                 case <-net.closed:
284                 }
285         }
286 }
287
288 func (net *Network) SearchTopic(topic Topic, setPeriod <-chan time.Duration, found chan<- *Node, lookup chan<- bool) {
289         for {
290                 select {
291                 case <-net.closed:
292                         return
293                 case delay, ok := <-setPeriod:
294                         select {
295                         case net.topicSearchReq <- topicSearchReq{topic: topic, found: found, lookup: lookup, delay: delay}:
296                         case <-net.closed:
297                                 return
298                         }
299                         if !ok {
300                                 return
301                         }
302                 }
303         }
304 }
305
306 func (net *Network) reqRefresh(nursery []*Node) <-chan struct{} {
307         select {
308         case net.refreshReq <- nursery:
309                 return <-net.refreshResp
310         case <-net.closed:
311                 return net.closed
312         }
313 }
314
315 func (net *Network) reqQueryFindnode(n *Node, target common.Hash, reply chan []*Node) bool {
316         q := &findnodeQuery{remote: n, target: target, reply: reply}
317         select {
318         case net.queryReq <- q:
319                 return true
320         case <-net.closed:
321                 return false
322         }
323 }
324
325 func (net *Network) reqReadPacket(pkt ingressPacket) {
326         select {
327         case net.read <- pkt:
328         case <-net.closed:
329         }
330 }
331
332 func (net *Network) reqTableOp(f func()) (called bool) {
333         select {
334         case net.tableOpReq <- f:
335                 <-net.tableOpResp
336                 return true
337         case <-net.closed:
338                 return false
339         }
340 }
341
342 // TODO: external address handling.
343
// topicSearchInfo is the main loop's bookkeeping for one active topic search.
type topicSearchInfo struct {
	lookupChn chan<- bool   // receives the radius "converged" flag after each search lookup
	period    time.Duration // delay before the topic is re-queued for another search lookup
}

// maxSearchCount caps how many topic search lookups the main loop runs at once.
const maxSearchCount = 5
350
351 func (net *Network) loop() {
352         var (
353                 refreshTimer       = time.NewTicker(autoRefreshInterval)
354                 bucketRefreshTimer = time.NewTimer(bucketRefreshInterval)
355                 refreshDone        chan struct{} // closed when the 'refresh' lookup has ended
356         )
357
358         // Tracking the next ticket to register.
359         var (
360                 nextTicket        *ticketRef
361                 nextRegisterTimer *time.Timer
362                 nextRegisterTime  <-chan time.Time
363         )
364         defer func() {
365                 if nextRegisterTimer != nil {
366                         nextRegisterTimer.Stop()
367                 }
368                 refreshTimer.Stop()
369                 bucketRefreshTimer.Stop()
370         }()
371         resetNextTicket := func() {
372                 ticket, timeout := net.ticketStore.nextFilteredTicket()
373                 if nextTicket != ticket {
374                         nextTicket = ticket
375                         if nextRegisterTimer != nil {
376                                 nextRegisterTimer.Stop()
377                                 nextRegisterTime = nil
378                         }
379                         if ticket != nil {
380                                 nextRegisterTimer = time.NewTimer(timeout)
381                                 nextRegisterTime = nextRegisterTimer.C
382                         }
383                 }
384         }
385
386         // Tracking registration and search lookups.
387         var (
388                 topicRegisterLookupTarget lookupInfo
389                 topicRegisterLookupDone   chan []*Node
390                 topicRegisterLookupTick   = time.NewTimer(0)
391                 searchReqWhenRefreshDone  []topicSearchReq
392                 searchInfo                = make(map[Topic]topicSearchInfo)
393                 activeSearchCount         int
394         )
395         topicSearchLookupDone := make(chan topicSearchResult, 100)
396         topicSearch := make(chan Topic, 100)
397         <-topicRegisterLookupTick.C
398
399         statsDump := time.NewTicker(10 * time.Second)
400         defer statsDump.Stop()
401
402 loop:
403         for {
404                 resetNextTicket()
405
406                 select {
407                 case <-net.closeReq:
408                         log.WithFields(log.Fields{"module": logModule}).Debug("close request")
409                         break loop
410
411                 // Ingress packet handling.
412                 case pkt := <-net.read:
413                         log.WithFields(log.Fields{"module": logModule}).Debug("read from net")
414                         n := net.internNode(&pkt)
415                         prestate := n.state
416                         status := "ok"
417                         if err := net.handle(n, pkt.ev, &pkt); err != nil {
418                                 status = err.Error()
419                         }
420                         log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": pkt.ev, "remote id": hex.EncodeToString(pkt.remoteID[:8]), "remote addr": pkt.remoteAddr, "pre state": prestate, "node state": n.state, "status": status}).Debug("handle ingress msg")
421
422                         // TODO: persist state if n.state goes >= known, delete if it goes <= known
423
424                 // State transition timeouts.
425                 case timeout := <-net.timeout:
426                         log.WithFields(log.Fields{"module": logModule}).Debug("net timeout")
427                         if net.timeoutTimers[timeout] == nil {
428                                 // Stale timer (was aborted).
429                                 continue
430                         }
431                         delete(net.timeoutTimers, timeout)
432                         prestate := timeout.node.state
433                         status := "ok"
434                         if err := net.handle(timeout.node, timeout.ev, nil); err != nil {
435                                 status = err.Error()
436                         }
437                         log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": timeout.ev, "node id": hex.EncodeToString(timeout.node.ID[:8]), "node addr": timeout.node.addr(), "pre state": prestate, "node state": timeout.node.state, "status": status}).Debug("handle timeout")
438
439                 // Querying.
440                 case q := <-net.queryReq:
441                         log.WithFields(log.Fields{"module": logModule}).Debug("net query request")
442                         if !q.start(net) {
443                                 q.remote.deferQuery(q)
444                         }
445
446                 // Interacting with the table.
447                 case f := <-net.tableOpReq:
448                         log.WithFields(log.Fields{"module": logModule}).Debug("net table operate request")
449                         f()
450                         net.tableOpResp <- struct{}{}
451
452                 // Topic registration stuff.
453                 case req := <-net.topicRegisterReq:
454                         log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic register request")
455                         if !req.add {
456                                 net.ticketStore.removeRegisterTopic(req.topic)
457                                 continue
458                         }
459                         net.ticketStore.addTopic(req.topic, true)
460                         // If we're currently waiting idle (nothing to look up), give the ticket store a
461                         // chance to start it sooner. This should speed up convergence of the radius
462                         // determination for new topics.
463                         // if topicRegisterLookupDone == nil {
464                         if topicRegisterLookupTarget.target == (common.Hash{}) {
465                                 log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("topic register lookup target null")
466                                 if topicRegisterLookupTick.Stop() {
467                                         <-topicRegisterLookupTick.C
468                                 }
469                                 target, delay := net.ticketStore.nextRegisterLookup()
470                                 topicRegisterLookupTarget = target
471                                 topicRegisterLookupTick.Reset(delay)
472                         }
473
474                 case nodes := <-topicRegisterLookupDone:
475                         log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup done")
476                         net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte {
477                                 net.ping(n, n.addr())
478                                 return n.pingEcho
479                         })
480                         target, delay := net.ticketStore.nextRegisterLookup()
481                         topicRegisterLookupTarget = target
482                         topicRegisterLookupTick.Reset(delay)
483                         topicRegisterLookupDone = nil
484
485                 case <-topicRegisterLookupTick.C:
486                         log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup tick")
487                         if (topicRegisterLookupTarget.target == common.Hash{}) {
488                                 target, delay := net.ticketStore.nextRegisterLookup()
489                                 topicRegisterLookupTarget = target
490                                 topicRegisterLookupTick.Reset(delay)
491                                 topicRegisterLookupDone = nil
492                         } else {
493                                 topicRegisterLookupDone = make(chan []*Node)
494                                 target := topicRegisterLookupTarget.target
495                                 go func() { topicRegisterLookupDone <- net.lookup(target, false) }()
496                         }
497
498                 case <-nextRegisterTime:
499                         log.WithFields(log.Fields{"module": logModule}).Debug("next register time")
500                         net.ticketStore.ticketRegistered(*nextTicket)
501                         net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong)
502
503                 case req := <-net.topicSearchReq:
504                         if refreshDone == nil {
505                                 log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic rearch req")
506                                 info, ok := searchInfo[req.topic]
507                                 if ok {
508                                         if req.delay == time.Duration(0) {
509                                                 delete(searchInfo, req.topic)
510                                                 net.ticketStore.removeSearchTopic(req.topic)
511                                         } else {
512                                                 info.period = req.delay
513                                                 searchInfo[req.topic] = info
514                                         }
515                                         continue
516                                 }
517                                 if req.delay != time.Duration(0) {
518                                         var info topicSearchInfo
519                                         info.period = req.delay
520                                         info.lookupChn = req.lookup
521                                         searchInfo[req.topic] = info
522                                         net.ticketStore.addSearchTopic(req.topic, req.found)
523                                         topicSearch <- req.topic
524                                 }
525                         } else {
526                                 searchReqWhenRefreshDone = append(searchReqWhenRefreshDone, req)
527                         }
528
529                 case topic := <-topicSearch:
530                         if activeSearchCount < maxSearchCount {
531                                 activeSearchCount++
532                                 target := net.ticketStore.nextSearchLookup(topic)
533                                 go func() {
534                                         nodes := net.lookup(target.target, false)
535                                         topicSearchLookupDone <- topicSearchResult{target: target, nodes: nodes}
536                                 }()
537                         }
538                         period := searchInfo[topic].period
539                         if period != time.Duration(0) {
540                                 go func() {
541                                         time.Sleep(period)
542                                         topicSearch <- topic
543                                 }()
544                         }
545
546                 case res := <-topicSearchLookupDone:
547                         activeSearchCount--
548                         if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil {
549                                 lookupChn <- net.ticketStore.radius[res.target.topic].converged
550                         }
551                         net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte {
552                                 if n.state != nil && n.state.canQuery {
553                                         return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration
554                                 } else {
555                                         if n.state == unknown {
556                                                 net.ping(n, n.addr())
557                                         }
558                                         return nil
559                                 }
560                         })
561
562                 case <-statsDump.C:
563                         log.WithFields(log.Fields{"module": logModule}).Debug("stats dump clock")
564                         /*r, ok := net.ticketStore.radius[testTopic]
565                         if !ok {
566                                 fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now())
567                         } else {
568                                 topics := len(net.ticketStore.tickets)
569                                 tickets := len(net.ticketStore.nodes)
570                                 rad := r.radius / (maxRadius/10000+1)
571                                 fmt.Printf("(%x) topics:%d radius:%d tickets:%d @ %v\n", net.tab.self.ID[:8], topics, rad, tickets, time.Now())
572                         }*/
573
574                         tm := Now()
575                         for topic, r := range net.ticketStore.radius {
576                                 if printTestImgLogs {
577                                         rad := r.radius / (maxRadius/1000000 + 1)
578                                         minrad := r.minRadius / (maxRadius/1000000 + 1)
579                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*R %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], rad)
580                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*MR %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], minrad)
581                                 }
582                         }
583                         for topic, t := range net.topictab.topics {
584                                 wp := t.wcl.nextWaitPeriod(tm)
585                                 if printTestImgLogs {
586                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*W %d %v %016x %d\n", tm/1000000, topic, net.tab.self.sha[:8], wp/1000000)
587                                 }
588                         }
589
590                 // Periodic / lookup-initiated bucket refresh.
591                 case <-refreshTimer.C:
592                         log.WithFields(log.Fields{"module": logModule}).Debug("refresh timer clock")
593                         // TODO: ideally we would start the refresh timer after
594                         // fallback nodes have been set for the first time.
595                         if refreshDone == nil {
596                                 refreshDone = make(chan struct{})
597                                 net.refresh(refreshDone)
598                         }
599                 case <-bucketRefreshTimer.C:
600                         target := net.tab.chooseBucketRefreshTarget()
601                         go func() {
602                                 net.lookup(target, false)
603                                 bucketRefreshTimer.Reset(bucketRefreshInterval)
604                         }()
605                 case newNursery := <-net.refreshReq:
606                         log.WithFields(log.Fields{"module": logModule}).Debug("net refresh request")
607                         if newNursery != nil {
608                                 net.nursery = newNursery
609                         }
610                         if refreshDone == nil {
611                                 refreshDone = make(chan struct{})
612                                 net.refresh(refreshDone)
613                         }
614                         net.refreshResp <- refreshDone
615                 case <-refreshDone:
616                         log.WithFields(log.Fields{"module": logModule, "table size": net.tab.count}).Debug("net refresh done")
617                         if net.tab.count != 0 {
618                                 refreshDone = nil
619                                 list := searchReqWhenRefreshDone
620                                 searchReqWhenRefreshDone = nil
621                                 go func() {
622                                         for _, req := range list {
623                                                 net.topicSearchReq <- req
624                                         }
625                                 }()
626                         } else {
627                                 refreshDone = make(chan struct{})
628                                 net.refresh(refreshDone)
629                         }
630                 }
631         }
632         log.WithFields(log.Fields{"module": logModule}).Debug("loop stopped,shutting down")
633         if net.conn != nil {
634                 net.conn.Close()
635         }
636         if refreshDone != nil {
637                 // TODO: wait for pending refresh.
638                 //<-refreshResults
639         }
640         // Cancel all pending timeouts.
641         for _, timer := range net.timeoutTimers {
642                 timer.Stop()
643         }
644         if net.db != nil {
645                 net.db.close()
646         }
647         close(net.closed)
648 }
649
650 // Everything below runs on the Network.loop goroutine
651 // and can modify Node, Table and Network at any time without locking.
652
653 func (net *Network) refresh(done chan<- struct{}) {
654         var seeds []*Node
655         if net.db != nil {
656                 seeds = net.db.querySeeds(seedCount, seedMaxAge)
657         }
658         if len(seeds) == 0 {
659                 seeds = net.nursery
660         }
661         if len(seeds) == 0 {
662                 log.WithFields(log.Fields{"module": logModule}).Debug("no seed nodes found")
663                 time.AfterFunc(time.Second*10, func() { close(done) })
664                 return
665         }
666         for _, n := range seeds {
667                 n = net.internNodeFromDB(n)
668                 if n.state == unknown {
669                         net.transition(n, verifyinit)
670                 }
671                 // Force-add the seed node so Lookup does something.
672                 // It will be deleted again if verification fails.
673                 net.tab.add(n)
674         }
675         // Start self lookup to fill up the buckets.
676         go func() {
677                 net.Lookup(net.tab.self.ID)
678                 close(done)
679         }()
680 }
681
682 // Node Interning.
683
684 func (net *Network) internNode(pkt *ingressPacket) *Node {
685         if n := net.nodes[pkt.remoteID]; n != nil {
686                 n.IP = pkt.remoteAddr.IP
687                 n.UDP = uint16(pkt.remoteAddr.Port)
688                 n.TCP = uint16(pkt.remoteAddr.Port)
689                 return n
690         }
691         n := NewNode(pkt.remoteID, pkt.remoteAddr.IP, uint16(pkt.remoteAddr.Port), uint16(pkt.remoteAddr.Port))
692         n.state = unknown
693         net.nodes[pkt.remoteID] = n
694         return n
695 }
696
697 func (net *Network) internNodeFromDB(dbn *Node) *Node {
698         if n := net.nodes[dbn.ID]; n != nil {
699                 return n
700         }
701         n := NewNode(dbn.ID, dbn.IP, dbn.UDP, dbn.TCP)
702         n.state = unknown
703         net.nodes[n.ID] = n
704         return n
705 }
706
707 func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
708         if rn.ID == net.tab.self.ID {
709                 return nil, errors.New("is self")
710         }
711         if rn.UDP <= lowPort {
712                 return nil, errors.New("low port")
713         }
714         n = net.nodes[rn.ID]
715         if n == nil {
716                 // We haven't seen this node before.
717                 n, err = nodeFromRPC(sender, rn)
718                 if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
719                         return n, errors.New("not contained in netrestrict whitelist")
720                 }
721                 if err == nil {
722                         n.state = unknown
723                         net.nodes[n.ID] = n
724                 }
725                 return n, err
726         }
727         if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP {
728                 if n.state == known {
729                         // reject address change if node is known by us
730                         err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n)
731                 } else {
732                         // accept otherwise; this will be handled nicer with signed ENRs
733                         n.IP = rn.IP
734                         n.UDP = rn.UDP
735                         n.TCP = rn.TCP
736                 }
737         }
738         return n, err
739 }
740
// nodeNetGuts is embedded in Node and contains the per-node fields used by
// the discovery protocol: the distance key used by the table and the
// state-machine bookkeeping driven by Network.
type nodeNetGuts struct {
	// This is a cached copy of sha3(ID) which is used for node
	// distance calculations. This is part of Node in order to make it
	// possible to write tests that need a node at a certain distance.
	// In those tests, the content of sha will not actually correspond
	// with ID.
	// NOTE(review): lookup targets in this file are built with
	// common.BytesToHash(id[:]) rather than a sha3 hash; confirm how the
	// node constructor fills this field before relying on the sha3 wording.
	sha common.Hash

	// State machine fields. Access to these fields
	// is restricted to the Network.loop goroutine.
	state             *nodeState       // current state; Network.handle treats nil as unknown
	pingEcho          []byte           // hash of last ping sent by us; ping() sends nothing while non-nil
	pingTopics        []Topic          // topic set sent by us in last ping
	deferredQueries   []*findnodeQuery // queries that can't be sent yet, started in order by startNextQuery
	pendingNeighbours *findnodeQuery   // current query, waiting for reply
	queryTimeouts     int              // query timeouts since the node last entered known or unknown
}
759
760 func (n *nodeNetGuts) deferQuery(q *findnodeQuery) {
761         n.deferredQueries = append(n.deferredQueries, q)
762 }
763
764 func (n *nodeNetGuts) startNextQuery(net *Network) {
765         if len(n.deferredQueries) == 0 {
766                 return
767         }
768         nextq := n.deferredQueries[0]
769         if nextq.start(net) {
770                 n.deferredQueries = append(n.deferredQueries[:0], n.deferredQueries[1:]...)
771         }
772 }
773
774 func (q *findnodeQuery) start(net *Network) bool {
775         // Satisfy queries against the local node directly.
776         if q.remote == net.tab.self {
777                 log.WithFields(log.Fields{"module": logModule}).Debug("findnodeQuery self")
778                 closest := net.tab.closest(common.BytesToHash(q.target[:]), bucketSize)
779
780                 q.reply <- closest.entries
781                 return true
782         }
783         if q.remote.state.canQuery && q.remote.pendingNeighbours == nil {
784                 log.WithFields(log.Fields{"module": logModule, "remote peer": q.remote.ID, "targetID": q.target}).Debug("find node query")
785                 net.conn.sendFindnodeHash(q.remote, q.target)
786                 net.timedEvent(respTimeout, q.remote, neighboursTimeout)
787                 q.remote.pendingNeighbours = q
788                 return true
789         }
790         // If the node is not known yet, it won't accept queries.
791         // Initiate the transition to known.
792         // The request will be sent later when the node reaches known state.
793         if q.remote.state == unknown {
794                 log.WithFields(log.Fields{"module": logModule, "id": q.remote.ID, "status": "unknown->verify init"}).Debug("find node query")
795                 net.transition(q.remote, verifyinit)
796         }
797         return false
798 }
799
800 // Node Events (the input to the state machine).
801
// nodeEvent identifies one input to the per-node state machine. It is either
// the type of a received packet or a locally generated timeout.
type nodeEvent uint

//go:generate stringer -type=nodeEvent

const (
	invalidEvent nodeEvent = iota // zero is reserved

	// Packet type events.
	// These correspond to packet types in the UDP protocol.
	pingPacket
	pongPacket
	findnodePacket
	neighborsPacket
	findnodeHashPacket
	topicRegisterPacket
	topicQueryPacket
	topicNodesPacket

	// Non-packet events.
	// Event values in this category are allocated outside
	// the packet type range (packet types are encoded as a single byte).
	// iota is 9 at this line, so pongTimeout is 265, pingTimeout 266 and
	// neighboursTimeout 267.
	pongTimeout nodeEvent = iota + 256 // armed by Network.ping
	pingTimeout                        // armed when entering remoteverifywait
	neighboursTimeout                  // armed when a findnode query is sent
)
827
828 // Node State Machine.
829
// nodeState is one state of the per-node discovery state machine. The
// concrete states are package variables filled in by init.
type nodeState struct {
	// name is the human-readable state name returned by String.
	name string
	// handle processes one event for a node in this state and returns the
	// state to transition to (Network.handle transitions even when err is
	// non-nil).
	handle func(*Network, *Node, nodeEvent, *ingressPacket) (next *nodeState, err error)
	// enter, if non-nil, is run by Network.transition when a node moves
	// into this state from a different one.
	enter func(*Network, *Node)
	// canQuery reports whether findnode queries may be sent directly to a
	// node in this state (see findnodeQuery.start).
	canQuery bool
}
836
837 func (s *nodeState) String() string {
838         return s.name
839 }
840
// The states of the per-node discovery state machine. They are assigned in
// init because their handlers refer to one another.
var (
	unknown          *nodeState // unverified; entering removes the node from the table and aborts its queries
	verifyinit       *nodeState // we pinged the node first and wait for its pong
	verifywait       *nodeState // the node pinged us; we wait for the pong to our own ping
	remoteverifywait *nodeState // our ping was answered; waits for a pingTimeout before becoming known
	known            *nodeState // verified, in the table and queryable
	contested        *nodeState // queryable node being re-pinged to decide whether it keeps its table slot
	unresponsive     *nodeState // contested node that missed its pong and was removed from the table
)
850
851 func init() {
852         unknown = &nodeState{
853                 name: "unknown",
854                 enter: func(net *Network, n *Node) {
855                         net.tab.delete(n)
856                         n.pingEcho = nil
857                         // Abort active queries.
858                         for _, q := range n.deferredQueries {
859                                 q.reply <- nil
860                         }
861                         n.deferredQueries = nil
862                         if n.pendingNeighbours != nil {
863                                 n.pendingNeighbours.reply <- nil
864                                 n.pendingNeighbours = nil
865                         }
866                         n.queryTimeouts = 0
867                 },
868                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
869                         switch ev {
870                         case pingPacket:
871                                 net.handlePing(n, pkt)
872                                 net.ping(n, pkt.remoteAddr)
873                                 return verifywait, nil
874                         default:
875                                 return unknown, errInvalidEvent
876                         }
877                 },
878         }
879
880         verifyinit = &nodeState{
881                 name: "verifyinit",
882                 enter: func(net *Network, n *Node) {
883                         net.ping(n, n.addr())
884                 },
885                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
886                         switch ev {
887                         case pingPacket:
888                                 net.handlePing(n, pkt)
889                                 return verifywait, nil
890                         case pongPacket:
891                                 err := net.handleKnownPong(n, pkt)
892                                 return remoteverifywait, err
893                         case pongTimeout:
894                                 return unknown, nil
895                         default:
896                                 return verifyinit, errInvalidEvent
897                         }
898                 },
899         }
900
901         verifywait = &nodeState{
902                 name: "verifywait",
903                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
904                         switch ev {
905                         case pingPacket:
906                                 net.handlePing(n, pkt)
907                                 return verifywait, nil
908                         case pongPacket:
909                                 err := net.handleKnownPong(n, pkt)
910                                 return known, err
911                         case pongTimeout:
912                                 return unknown, nil
913                         default:
914                                 return verifywait, errInvalidEvent
915                         }
916                 },
917         }
918
919         remoteverifywait = &nodeState{
920                 name: "remoteverifywait",
921                 enter: func(net *Network, n *Node) {
922                         net.timedEvent(respTimeout, n, pingTimeout)
923                 },
924                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
925                         switch ev {
926                         case pingPacket:
927                                 net.handlePing(n, pkt)
928                                 return remoteverifywait, nil
929                         case pingTimeout:
930                                 return known, nil
931                         default:
932                                 return remoteverifywait, errInvalidEvent
933                         }
934                 },
935         }
936
937         known = &nodeState{
938                 name:     "known",
939                 canQuery: true,
940                 enter: func(net *Network, n *Node) {
941                         n.queryTimeouts = 0
942                         n.startNextQuery(net)
943                         // Insert into the table and start revalidation of the last node
944                         // in the bucket if it is full.
945                         last := net.tab.add(n)
946                         if last != nil && last.state == known {
947                                 // TODO: do this asynchronously
948                                 net.transition(last, contested)
949                         }
950                 },
951                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
952                         if err := net.db.updateNode(n); err != nil {
953                                 return known, err
954                         }
955
956                         switch ev {
957                         case pingPacket:
958                                 net.handlePing(n, pkt)
959                                 return known, nil
960                         case pongPacket:
961                                 err := net.handleKnownPong(n, pkt)
962                                 return known, err
963                         default:
964                                 return net.handleQueryEvent(n, ev, pkt)
965                         }
966                 },
967         }
968
969         contested = &nodeState{
970                 name:     "contested",
971                 canQuery: true,
972                 enter: func(net *Network, n *Node) {
973                         n.pingEcho = nil
974                         net.ping(n, n.addr())
975                 },
976                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
977                         switch ev {
978                         case pongPacket:
979                                 // Node is still alive.
980                                 err := net.handleKnownPong(n, pkt)
981                                 return known, err
982                         case pongTimeout:
983                                 net.tab.deleteReplace(n)
984                                 return unresponsive, nil
985                         case pingPacket:
986                                 net.handlePing(n, pkt)
987                                 return contested, nil
988                         default:
989                                 return net.handleQueryEvent(n, ev, pkt)
990                         }
991                 },
992         }
993
994         unresponsive = &nodeState{
995                 name:     "unresponsive",
996                 canQuery: true,
997                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
998                         net.db.deleteNode(n.ID)
999
1000                         switch ev {
1001                         case pingPacket:
1002                                 net.handlePing(n, pkt)
1003                                 return known, nil
1004                         case pongPacket:
1005                                 err := net.handleKnownPong(n, pkt)
1006                                 return known, err
1007                         default:
1008                                 return net.handleQueryEvent(n, ev, pkt)
1009                         }
1010                 },
1011         }
1012 }
1013
1014 // handle processes packets sent by n and events related to n.
1015 func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error {
1016         //fmt.Println("handle", n.addr().String(), n.state, ev)
1017         if pkt != nil {
1018                 if err := net.checkPacket(n, ev, pkt); err != nil {
1019                         //fmt.Println("check err:", err)
1020                         return err
1021                 }
1022                 // Start the background expiration goroutine after the first
1023                 // successful communication. Subsequent calls have no effect if it
1024                 // is already running. We do this here instead of somewhere else
1025                 // so that the search for seed nodes also considers older nodes
1026                 // that would otherwise be removed by the expirer.
1027                 if net.db != nil {
1028                         net.db.ensureExpirer()
1029                 }
1030         }
1031         if n.state == nil {
1032                 n.state = unknown //???
1033         }
1034         next, err := n.state.handle(net, n, ev, pkt)
1035         net.transition(n, next)
1036         //fmt.Println("new state:", n.state)
1037         return err
1038 }
1039
1040 func (net *Network) checkPacket(n *Node, ev nodeEvent, pkt *ingressPacket) error {
1041         // Replay prevention checks.
1042         switch ev {
1043         case pingPacket, findnodeHashPacket, neighborsPacket:
1044                 // TODO: check date is > last date seen
1045                 // TODO: check ping version
1046         case pongPacket:
1047                 if !bytes.Equal(pkt.data.(*pong).ReplyTok, n.pingEcho) {
1048                         // fmt.Println("pong reply token mismatch")
1049                         return fmt.Errorf("pong reply token mismatch")
1050                 }
1051                 n.pingEcho = nil
1052         }
1053         // Address validation.
1054         // TODO: Ideally we would do the following:
1055         //  - reject all packets with wrong address except ping.
1056         //  - for ping with new address, transition to verifywait but keep the
1057         //    previous node (with old address) around. if the new one reaches known,
1058         //    swap it out.
1059         return nil
1060 }
1061
1062 func (net *Network) transition(n *Node, next *nodeState) {
1063         if n.state != next {
1064                 n.state = next
1065                 if next.enter != nil {
1066                         next.enter(net, n)
1067                 }
1068         }
1069
1070         // TODO: persist/unpersist node
1071 }
1072
1073 func (net *Network) timedEvent(d time.Duration, n *Node, ev nodeEvent) {
1074         timeout := timeoutEvent{ev, n}
1075         net.timeoutTimers[timeout] = time.AfterFunc(d, func() {
1076                 select {
1077                 case net.timeout <- timeout:
1078                 case <-net.closed:
1079                 }
1080         })
1081 }
1082
1083 func (net *Network) abortTimedEvent(n *Node, ev nodeEvent) {
1084         timer := net.timeoutTimers[timeoutEvent{ev, n}]
1085         if timer != nil {
1086                 timer.Stop()
1087                 delete(net.timeoutTimers, timeoutEvent{ev, n})
1088         }
1089 }
1090
1091 func (net *Network) ping(n *Node, addr *net.UDPAddr) {
1092         //fmt.Println("ping", n.addr().String(), n.ID.String(), n.sha.Hex())
1093         if n.pingEcho != nil || n.ID == net.tab.self.ID {
1094                 //fmt.Println(" not sent")
1095                 return
1096         }
1097         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Pinging remote node")
1098         n.pingTopics = net.ticketStore.regTopicSet()
1099         n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics)
1100         net.timedEvent(respTimeout, n, pongTimeout)
1101 }
1102
1103 func (net *Network) handlePing(n *Node, pkt *ingressPacket) {
1104         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling remote ping")
1105         ping := pkt.data.(*ping)
1106         n.TCP = ping.From.TCP
1107         t := net.topictab.getTicket(n, ping.Topics)
1108
1109         pong := &pong{
1110                 To:         makeEndpoint(n.addr(), n.TCP), // TODO: maybe use known TCP port from DB
1111                 ReplyTok:   pkt.hash,
1112                 Expiration: uint64(time.Now().Add(expiration).Unix()),
1113         }
1114         ticketToPong(t, pong)
1115         net.conn.send(n, pongPacket, pong)
1116 }
1117
1118 func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error {
1119         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling known pong")
1120         net.abortTimedEvent(n, pongTimeout)
1121         now := Now()
1122         ticket, err := pongToTicket(now, n.pingTopics, n, pkt)
1123         if err == nil {
1124                 // fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data)
1125                 net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket)
1126         } else {
1127                 log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Failed to convert pong to ticket")
1128         }
1129         n.pingEcho = nil
1130         n.pingTopics = nil
1131         net.db.updateLastPong(n.ID, time.Now())
1132         return err
1133 }
1134
1135 func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
1136         switch ev {
1137         case findnodePacket:
1138                 target := common.BytesToHash(pkt.data.(*findnode).Target[:])
1139                 results := net.tab.closest(target, bucketSize).entries
1140                 net.conn.sendNeighbours(n, results)
1141                 return n.state, nil
1142         case neighborsPacket:
1143                 err := net.handleNeighboursPacket(n, pkt)
1144                 return n.state, err
1145         case neighboursTimeout:
1146                 if n.pendingNeighbours != nil {
1147                         n.pendingNeighbours.reply <- nil
1148                         n.pendingNeighbours = nil
1149                 }
1150                 n.queryTimeouts++
1151                 if n.queryTimeouts > maxFindnodeFailures && n.state == known {
1152                         return contested, errors.New("too many timeouts")
1153                 }
1154                 return n.state, nil
1155
1156         // v5
1157
1158         case findnodeHashPacket:
1159                 results := net.tab.closest(pkt.data.(*findnodeHash).Target, bucketSize).entries
1160                 net.conn.sendNeighbours(n, results)
1161                 return n.state, nil
1162         case topicRegisterPacket:
1163                 //fmt.Println("got topicRegisterPacket")
1164                 regdata := pkt.data.(*topicRegister)
1165                 pong, err := net.checkTopicRegister(regdata)
1166                 if err != nil {
1167                         //fmt.Println(err)
1168                         return n.state, fmt.Errorf("bad waiting ticket: %v", err)
1169                 }
1170                 net.topictab.useTicket(n, pong.TicketSerial, regdata.Topics, int(regdata.Idx), pong.Expiration, pong.WaitPeriods)
1171                 return n.state, nil
1172         case topicQueryPacket:
1173                 // TODO: handle expiration
1174                 topic := pkt.data.(*topicQuery).Topic
1175                 results := net.topictab.getEntries(topic)
1176                 if _, ok := net.ticketStore.tickets[topic]; ok {
1177                         results = append(results, net.tab.self) // we're not registering in our own table but if we're advertising, return ourselves too
1178                 }
1179                 if len(results) > 10 {
1180                         results = results[:10]
1181                 }
1182                 var hash common.Hash
1183                 copy(hash[:], pkt.hash)
1184                 net.conn.sendTopicNodes(n, hash, results)
1185                 return n.state, nil
1186         case topicNodesPacket:
1187                 p := pkt.data.(*topicNodes)
1188                 if net.ticketStore.gotTopicNodes(n, p.Echo, p.Nodes) {
1189                         n.queryTimeouts++
1190                         if n.queryTimeouts > maxFindnodeFailures && n.state == known {
1191                                 return contested, errors.New("too many timeouts")
1192                         }
1193                 }
1194                 return n.state, nil
1195
1196         default:
1197                 return n.state, errInvalidEvent
1198         }
1199 }
1200
1201 func (net *Network) checkTopicRegister(data *topicRegister) (*pong, error) {
1202         var pongpkt ingressPacket
1203         if err := decodePacket(data.Pong, &pongpkt); err != nil {
1204                 return nil, err
1205         }
1206         if pongpkt.ev != pongPacket {
1207                 return nil, errors.New("is not pong packet")
1208         }
1209         if pongpkt.remoteID != net.tab.self.ID {
1210                 return nil, errors.New("not signed by us")
1211         }
1212         // check that we previously authorised all topics
1213         // that the other side is trying to register.
1214         hash, _, _ := wireHash(data.Topics)
1215         if hash != pongpkt.data.(*pong).TopicHash {
1216                 return nil, errors.New("topic hash mismatch")
1217         }
1218         if int(data.Idx) < 0 || int(data.Idx) >= len(data.Topics) {
1219                 return nil, errors.New("topic index out of range")
1220         }
1221         return pongpkt.data.(*pong), nil
1222 }
1223
1224 func wireHash(x interface{}) (h common.Hash, n int, err error) {
1225         hw := sha3.New256()
1226         wire.WriteBinary(x, hw, &n, &err)
1227         hw.Sum(h[:0])
1228         return h, n, err
1229 }
1230
1231 func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
1232         if n.pendingNeighbours == nil {
1233                 return errNoQuery
1234         }
1235         net.abortTimedEvent(n, neighboursTimeout)
1236
1237         req := pkt.data.(*neighbors)
1238         nodes := make([]*Node, len(req.Nodes))
1239         for i, rn := range req.Nodes {
1240                 nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
1241                 if err != nil {
1242                         log.WithFields(log.Fields{"module": logModule, "ip": rn.IP, "id:": n.ID[:8], "addr:": pkt.remoteAddr, "error": err}).Debug("invalid neighbour")
1243                         continue
1244                 }
1245                 nodes[i] = nn
1246                 // Start validation of query results immediately.
1247                 // This fills the table quickly.
1248                 // TODO: generates way too many packets, maybe do it via queue.
1249                 if nn.state == unknown {
1250                         net.transition(nn, verifyinit)
1251                 }
1252         }
1253         // TODO: don't ignore second packet
1254         n.pendingNeighbours.reply <- nodes
1255         n.pendingNeighbours = nil
1256         // Now that this query is done, start the next one.
1257         n.startNextQuery(net)
1258         return nil
1259 }