OSDN Git Service

ee7aea1071a27195e50842b849e1153ab4ae5822
[bytom/vapor.git] / p2p / discover / dht / net.go
1 package dht
2
3 import (
4         "bytes"
5         "encoding/hex"
6         "errors"
7         "fmt"
8         "net"
9         "time"
10
11         log "github.com/sirupsen/logrus"
12         "github.com/tendermint/go-wire"
13         "golang.org/x/crypto/sha3"
14
15         "github.com/vapor/common"
16         "github.com/vapor/crypto/ed25519"
17         "github.com/vapor/p2p/netutil"
18 )
19
// Sentinel errors of the discovery network.
// NOTE(review): none of these is referenced in the visible part of this file;
// the per-error notes below are inferred from the messages only — confirm
// against the code that returns them.
var (
	// presumably: the event is not acceptable in the node's current state
	errInvalidEvent = errors.New("invalid in current state")
	// presumably: a reply arrived without a matching outstanding query
	errNoQuery      = errors.New("no pending query")
	// presumably: the packet's source address does not match the known node
	errWrongAddress = errors.New("unknown sender address")
)
25
// Tuning parameters of the discovery network.
const (
	// autoRefreshInterval is the period of the ticker that triggers a table
	// refresh in the main loop.
	autoRefreshInterval = time.Hour
	// bucketRefreshInterval is the delay before the next bucket refresh
	// lookup is started by the main loop.
	bucketRefreshInterval = time.Minute
	// seedCount and seedMaxAge are passed to the node database when
	// querying seed nodes for a refresh.
	seedCount  = 30
	seedMaxAge = 5 * 24 * time.Hour
	// lowPort is the threshold at or below which a neighbour's UDP port is
	// rejected as a "low port".
	lowPort = 1024
)
33
// printTestImgLogs gates the extra "*R", "*MR" and "*W" debug lines emitted
// by the periodic stats dump in the main loop. It is disabled.
const printTestImgLogs = false
37
// Network manages the table and all protocol interaction.
//
// Callers on other goroutines communicate with the main loop through the
// request channels below; the loop goroutine is the one that mutates the
// table, node and ticket state.
type Network struct {
	db          *nodeDB          // database of known nodes; nil when newNetwork is given the "<no database>" path
	conn        transport        // transport used to send packets
	netrestrict *netutil.Netlist // optional whitelist checked for nodes learned from neighbours

	closed           chan struct{}          // closed when loop is done
	closeReq         chan struct{}          // 'request to close'
	refreshReq       chan []*Node           // lookups ask for refresh on this channel
	refreshResp      chan (<-chan struct{}) // ...and get the channel to block on from this one
	read             chan ingressPacket     // ingress packets arrive here
	timeout          chan timeoutEvent      // state transition timeouts arrive here
	queryReq         chan *findnodeQuery    // lookups submit findnode queries on this channel
	tableOpReq       chan func()            // functions to be executed by the main loop
	tableOpResp      chan struct{}          // signalled once a tableOpReq function has returned
	topicRegisterReq chan topicRegisterReq  // requests to start/stop registering a topic
	topicSearchReq   chan topicSearchReq    // requests to start, re-time or stop a topic search

	// State of the main loop.
	tab           *Table
	topictab      *topicTable
	ticketStore   *ticketStore
	nursery       []*Node                      // fallback seeds, used by refresh when the database yields none
	nodes         map[NodeID]*Node             // tracks active nodes with state != known
	timeoutTimers map[timeoutEvent]*time.Timer // pending timers; a missing entry marks a stale timeout event

	// Revalidation queues.
	// Nodes put on these queues will be pinged eventually.
	slowRevalidateQueue []*Node
	fastRevalidateQueue []*Node

	// Buffers for state transition.
	sendBuf []*ingressPacket
}
72
// transport is implemented by the UDP transport.
// It is an interface so we can test without opening lots of UDP
// sockets and without generating a private key.
type transport interface {
	// Senders for the individual packet kinds.
	sendPing(remote *Node, remoteAddr *net.UDPAddr, topics []Topic) (hash []byte)
	sendNeighbours(remote *Node, nodes []*Node)
	sendFindnodeHash(remote *Node, target common.Hash)
	sendTopicRegister(remote *Node, topics []Topic, topicIdx int, pong []byte)
	sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)

	// send transmits packet p of type ptype to remote.
	// NOTE(review): the returned hash is presumably the hash of the sent
	// packet — confirm in the implementation.
	send(remote *Node, ptype nodeEvent, p interface{}) (hash []byte)

	// localAddr is used by newNetwork to create the table's local entry.
	localAddr() *net.UDPAddr
	getNetID() uint64
	// Close is called by Network.Close and again when the main loop exits.
	Close()
}
89
// findnodeQuery is a findnode request that a lookup hands to the main loop
// through Network.queryReq.
type findnodeQuery struct {
	remote   *Node          // node to be queried
	target   common.Hash    // hash the lookup is searching for
	reply    chan<- []*Node // the lookup receives the resulting nodes here
	nresults int            // counter for received nodes
}
96
// topicRegisterReq asks the main loop to start (add == true) or stop
// (add == false) registering the given topic in the ticket store.
type topicRegisterReq struct {
	add   bool
	topic Topic
}
101
// topicSearchReq asks the main loop to start, re-time or stop searching
// for a topic.
type topicSearchReq struct {
	topic  Topic
	found  chan<- *Node  // handed to the ticket store when the search is added
	lookup chan<- bool   // receives the radius "converged" flag after each search lookup
	delay  time.Duration // period between searches; zero removes an existing search
}
108
// topicSearchResult carries the outcome of one topic search lookup from the
// lookup goroutine back to the main loop.
type topicSearchResult struct {
	target lookupInfo // the lookup target chosen by the ticket store
	nodes  []*Node    // nodes returned by the lookup
}
113
// timeoutEvent identifies a state-transition timeout. It is delivered on
// Network.timeout and doubles as the key of Network.timeoutTimers; the main
// loop handles it by passing ev to the handler for node.
type timeoutEvent struct {
	ev   nodeEvent
	node *Node
}
118
119 func newNetwork(conn transport, ourPubkey ed25519.PublicKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
120         var ourID NodeID
121         copy(ourID[:], ourPubkey[:nodeIDBits])
122
123         var db *nodeDB
124         if dbPath != "<no database>" {
125                 var err error
126                 if db, err = newNodeDB(dbPath, Version, ourID); err != nil {
127                         return nil, err
128                 }
129         }
130
131         tab := newTable(ourID, conn.localAddr())
132         net := &Network{
133                 db:               db,
134                 conn:             conn,
135                 netrestrict:      netrestrict,
136                 tab:              tab,
137                 topictab:         newTopicTable(db, tab.self),
138                 ticketStore:      newTicketStore(),
139                 refreshReq:       make(chan []*Node),
140                 refreshResp:      make(chan (<-chan struct{})),
141                 closed:           make(chan struct{}),
142                 closeReq:         make(chan struct{}),
143                 read:             make(chan ingressPacket, 100),
144                 timeout:          make(chan timeoutEvent),
145                 timeoutTimers:    make(map[timeoutEvent]*time.Timer),
146                 tableOpReq:       make(chan func()),
147                 tableOpResp:      make(chan struct{}),
148                 queryReq:         make(chan *findnodeQuery),
149                 topicRegisterReq: make(chan topicRegisterReq),
150                 topicSearchReq:   make(chan topicSearchReq),
151                 nodes:            make(map[NodeID]*Node),
152         }
153         go net.loop()
154         return net, nil
155 }
156
157 // Close terminates the network listener and flushes the node database.
158 func (net *Network) Close() {
159         net.conn.Close()
160         select {
161         case <-net.closed:
162         case net.closeReq <- struct{}{}:
163                 <-net.closed
164         }
165 }
166
167 // Self returns the local node.
168 // The returned node should not be modified by the caller.
169 func (net *Network) Self() *Node {
170         return net.tab.self
171 }
172
173 func (net *Network) selfIP() net.IP {
174         return net.tab.self.IP
175 }
176
177 // ReadRandomNodes fills the given slice with random nodes from the
178 // table. It will not write the same node more than once. The nodes in
179 // the slice are copies and can be modified by the caller.
180 func (net *Network) ReadRandomNodes(buf []*Node) (n int) {
181         net.reqTableOp(func() { n = net.tab.readRandomNodes(buf) })
182         return n
183 }
184
185 // SetFallbackNodes sets the initial points of contact. These nodes
186 // are used to connect to the network if the table is empty and there
187 // are no known nodes in the database.
188 func (net *Network) SetFallbackNodes(nodes []*Node) error {
189         nursery := make([]*Node, 0, len(nodes))
190         for _, n := range nodes {
191                 if err := n.validateComplete(); err != nil {
192                         return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
193                 }
194                 // Recompute cpy.sha because the node might not have been
195                 // created by NewNode or ParseNode.
196                 cpy := *n
197                 cpy.sha = common.BytesToHash(n.ID[:])
198                 nursery = append(nursery, &cpy)
199         }
200         net.reqRefresh(nursery)
201         return nil
202 }
203
204 // Resolve searches for a specific node with the given ID.
205 // It returns nil if the node could not be found.
206 func (net *Network) Resolve(targetID NodeID) *Node {
207         result := net.lookup(common.BytesToHash(targetID[:]), true)
208         for _, n := range result {
209                 if n.ID == targetID {
210                         return n
211                 }
212         }
213         return nil
214 }
215
216 // Lookup performs a network search for nodes close
217 // to the given target. It approaches the target by querying
218 // nodes that are closer to it on each iteration.
219 // The given target does not need to be an actual node
220 // identifier.
221 //
222 // The local node may be included in the result.
223 func (net *Network) Lookup(targetID NodeID) []*Node {
224         return net.lookup(common.BytesToHash(targetID[:]), false)
225 }
226
227 func (net *Network) lookup(target common.Hash, stopOnMatch bool) []*Node {
228         var (
229                 asked          = make(map[NodeID]bool)
230                 seen           = make(map[NodeID]bool)
231                 reply          = make(chan []*Node, alpha)
232                 result         = nodesByDistance{target: target}
233                 pendingQueries = 0
234         )
235         // Get initial answers from the local node.
236         result.push(net.tab.self, bucketSize)
237         for {
238                 // Ask the Î± closest nodes that we haven't asked yet.
239                 for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
240                         n := result.entries[i]
241                         if !asked[n.ID] {
242                                 asked[n.ID] = true
243                                 pendingQueries++
244                                 net.reqQueryFindnode(n, target, reply)
245                         }
246                 }
247                 if pendingQueries == 0 {
248                         // We have asked all closest nodes, stop the search.
249                         break
250                 }
251                 // Wait for the next reply.
252                 select {
253                 case nodes := <-reply:
254                         for _, n := range nodes {
255                                 if n != nil && !seen[n.ID] {
256                                         seen[n.ID] = true
257                                         result.push(n, bucketSize)
258                                         if stopOnMatch && n.sha == target {
259                                                 return result.entries
260                                         }
261                                 }
262                         }
263                         pendingQueries--
264                 case <-time.After(respTimeout):
265                         // forget all pending requests, start new ones
266                         pendingQueries = 0
267                         reply = make(chan []*Node, alpha)
268                 }
269         }
270         return result.entries
271 }
272
273 func (net *Network) RegisterTopic(topic Topic, stop <-chan struct{}) {
274         select {
275         case net.topicRegisterReq <- topicRegisterReq{true, topic}:
276         case <-net.closed:
277                 return
278         }
279         select {
280         case <-net.closed:
281         case <-stop:
282                 select {
283                 case net.topicRegisterReq <- topicRegisterReq{false, topic}:
284                 case <-net.closed:
285                 }
286         }
287 }
288
289 func (net *Network) SearchTopic(topic Topic, setPeriod <-chan time.Duration, found chan<- *Node, lookup chan<- bool) {
290         for {
291                 select {
292                 case <-net.closed:
293                         return
294                 case delay, ok := <-setPeriod:
295                         select {
296                         case net.topicSearchReq <- topicSearchReq{topic: topic, found: found, lookup: lookup, delay: delay}:
297                         case <-net.closed:
298                                 return
299                         }
300                         if !ok {
301                                 return
302                         }
303                 }
304         }
305 }
306
307 func (net *Network) reqRefresh(nursery []*Node) <-chan struct{} {
308         select {
309         case net.refreshReq <- nursery:
310                 return <-net.refreshResp
311         case <-net.closed:
312                 return net.closed
313         }
314 }
315
316 func (net *Network) reqQueryFindnode(n *Node, target common.Hash, reply chan []*Node) bool {
317         q := &findnodeQuery{remote: n, target: target, reply: reply}
318         select {
319         case net.queryReq <- q:
320                 return true
321         case <-net.closed:
322                 return false
323         }
324 }
325
326 func (net *Network) reqReadPacket(pkt ingressPacket) {
327         select {
328         case net.read <- pkt:
329         case <-net.closed:
330         }
331 }
332
333 func (net *Network) reqTableOp(f func()) (called bool) {
334         select {
335         case net.tableOpReq <- f:
336                 <-net.tableOpResp
337                 return true
338         case <-net.closed:
339                 return false
340         }
341 }
342
343 // TODO: external address handling.
344
// topicSearchInfo is the main loop's per-topic record of an active search.
type topicSearchInfo struct {
	lookupChn chan<- bool   // receives the radius "converged" flag after each search lookup
	period    time.Duration // delay before the topic is searched again
}
349
// maxSearchCount caps the number of topic search lookups the main loop keeps
// running at the same time.
const maxSearchCount = 5
351
352 func (net *Network) loop() {
353         var (
354                 refreshTimer       = time.NewTicker(autoRefreshInterval)
355                 bucketRefreshTimer = time.NewTimer(bucketRefreshInterval)
356                 refreshDone        chan struct{} // closed when the 'refresh' lookup has ended
357         )
358
359         // Tracking the next ticket to register.
360         var (
361                 nextTicket        *ticketRef
362                 nextRegisterTimer *time.Timer
363                 nextRegisterTime  <-chan time.Time
364         )
365         defer func() {
366                 if nextRegisterTimer != nil {
367                         nextRegisterTimer.Stop()
368                 }
369                 refreshTimer.Stop()
370                 bucketRefreshTimer.Stop()
371         }()
372         resetNextTicket := func() {
373                 ticket, timeout := net.ticketStore.nextFilteredTicket()
374                 if nextTicket != ticket {
375                         nextTicket = ticket
376                         if nextRegisterTimer != nil {
377                                 nextRegisterTimer.Stop()
378                                 nextRegisterTime = nil
379                         }
380                         if ticket != nil {
381                                 nextRegisterTimer = time.NewTimer(timeout)
382                                 nextRegisterTime = nextRegisterTimer.C
383                         }
384                 }
385         }
386
387         // Tracking registration and search lookups.
388         var (
389                 topicRegisterLookupTarget lookupInfo
390                 topicRegisterLookupDone   chan []*Node
391                 topicRegisterLookupTick   = time.NewTimer(0)
392                 searchReqWhenRefreshDone  []topicSearchReq
393                 searchInfo                = make(map[Topic]topicSearchInfo)
394                 activeSearchCount         int
395         )
396         topicSearchLookupDone := make(chan topicSearchResult, 100)
397         topicSearch := make(chan Topic, 100)
398         <-topicRegisterLookupTick.C
399
400         statsDump := time.NewTicker(10 * time.Second)
401         defer statsDump.Stop()
402
403 loop:
404         for {
405                 resetNextTicket()
406
407                 select {
408                 case <-net.closeReq:
409                         log.WithFields(log.Fields{"module": logModule}).Debug("close request")
410                         break loop
411
412                 // Ingress packet handling.
413                 case pkt := <-net.read:
414                         log.WithFields(log.Fields{"module": logModule}).Debug("read from net")
415                         n := net.internNode(&pkt)
416                         prestate := n.state
417                         status := "ok"
418                         if err := net.handle(n, pkt.ev, &pkt); err != nil {
419                                 status = err.Error()
420                         }
421                         log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": pkt.ev, "remote id": hex.EncodeToString(pkt.remoteID[:8]), "remote addr": pkt.remoteAddr, "pre state": prestate, "node state": n.state, "status": status}).Debug("handle ingress msg")
422
423                         // TODO: persist state if n.state goes >= known, delete if it goes <= known
424
425                 // State transition timeouts.
426                 case timeout := <-net.timeout:
427                         log.WithFields(log.Fields{"module": logModule}).Debug("net timeout")
428                         if net.timeoutTimers[timeout] == nil {
429                                 // Stale timer (was aborted).
430                                 continue
431                         }
432                         delete(net.timeoutTimers, timeout)
433                         prestate := timeout.node.state
434                         status := "ok"
435                         if err := net.handle(timeout.node, timeout.ev, nil); err != nil {
436                                 status = err.Error()
437                         }
438                         log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": timeout.ev, "node id": hex.EncodeToString(timeout.node.ID[:8]), "node addr": timeout.node.addr(), "pre state": prestate, "node state": timeout.node.state, "status": status}).Debug("handle timeout")
439
440                 // Querying.
441                 case q := <-net.queryReq:
442                         log.WithFields(log.Fields{"module": logModule}).Debug("net query request")
443                         if !q.start(net) {
444                                 q.remote.deferQuery(q)
445                         }
446
447                 // Interacting with the table.
448                 case f := <-net.tableOpReq:
449                         log.WithFields(log.Fields{"module": logModule}).Debug("net table operate request")
450                         f()
451                         net.tableOpResp <- struct{}{}
452
453                 // Topic registration stuff.
454                 case req := <-net.topicRegisterReq:
455                         log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic register request")
456                         if !req.add {
457                                 net.ticketStore.removeRegisterTopic(req.topic)
458                                 continue
459                         }
460                         net.ticketStore.addTopic(req.topic, true)
461                         // If we're currently waiting idle (nothing to look up), give the ticket store a
462                         // chance to start it sooner. This should speed up convergence of the radius
463                         // determination for new topics.
464                         // if topicRegisterLookupDone == nil {
465                         if topicRegisterLookupTarget.target == (common.Hash{}) {
466                                 log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("topic register lookup target null")
467                                 if topicRegisterLookupTick.Stop() {
468                                         <-topicRegisterLookupTick.C
469                                 }
470                                 target, delay := net.ticketStore.nextRegisterLookup()
471                                 topicRegisterLookupTarget = target
472                                 topicRegisterLookupTick.Reset(delay)
473                         }
474
475                 case nodes := <-topicRegisterLookupDone:
476                         log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup done")
477                         net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte {
478                                 net.ping(n, n.addr())
479                                 return n.pingEcho
480                         })
481                         target, delay := net.ticketStore.nextRegisterLookup()
482                         topicRegisterLookupTarget = target
483                         topicRegisterLookupTick.Reset(delay)
484                         topicRegisterLookupDone = nil
485
486                 case <-topicRegisterLookupTick.C:
487                         log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup tick")
488                         if (topicRegisterLookupTarget.target == common.Hash{}) {
489                                 target, delay := net.ticketStore.nextRegisterLookup()
490                                 topicRegisterLookupTarget = target
491                                 topicRegisterLookupTick.Reset(delay)
492                                 topicRegisterLookupDone = nil
493                         } else {
494                                 topicRegisterLookupDone = make(chan []*Node)
495                                 target := topicRegisterLookupTarget.target
496                                 go func() { topicRegisterLookupDone <- net.lookup(target, false) }()
497                         }
498
499                 case <-nextRegisterTime:
500                         log.WithFields(log.Fields{"module": logModule}).Debug("next register time")
501                         net.ticketStore.ticketRegistered(*nextTicket)
502                         net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong)
503
504                 case req := <-net.topicSearchReq:
505                         if refreshDone == nil {
506                                 log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic rearch req")
507                                 info, ok := searchInfo[req.topic]
508                                 if ok {
509                                         if req.delay == time.Duration(0) {
510                                                 delete(searchInfo, req.topic)
511                                                 net.ticketStore.removeSearchTopic(req.topic)
512                                         } else {
513                                                 info.period = req.delay
514                                                 searchInfo[req.topic] = info
515                                         }
516                                         continue
517                                 }
518                                 if req.delay != time.Duration(0) {
519                                         var info topicSearchInfo
520                                         info.period = req.delay
521                                         info.lookupChn = req.lookup
522                                         searchInfo[req.topic] = info
523                                         net.ticketStore.addSearchTopic(req.topic, req.found)
524                                         topicSearch <- req.topic
525                                 }
526                         } else {
527                                 searchReqWhenRefreshDone = append(searchReqWhenRefreshDone, req)
528                         }
529
530                 case topic := <-topicSearch:
531                         if activeSearchCount < maxSearchCount {
532                                 activeSearchCount++
533                                 target := net.ticketStore.nextSearchLookup(topic)
534                                 go func() {
535                                         nodes := net.lookup(target.target, false)
536                                         topicSearchLookupDone <- topicSearchResult{target: target, nodes: nodes}
537                                 }()
538                         }
539                         period := searchInfo[topic].period
540                         if period != time.Duration(0) {
541                                 go func() {
542                                         time.Sleep(period)
543                                         topicSearch <- topic
544                                 }()
545                         }
546
547                 case res := <-topicSearchLookupDone:
548                         activeSearchCount--
549                         if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil {
550                                 lookupChn <- net.ticketStore.radius[res.target.topic].converged
551                         }
552                         net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte {
553                                 if n.state != nil && n.state.canQuery {
554                                         return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration
555                                 } else {
556                                         if n.state == unknown {
557                                                 net.ping(n, n.addr())
558                                         }
559                                         return nil
560                                 }
561                         })
562
563                 case <-statsDump.C:
564                         log.WithFields(log.Fields{"module": logModule}).Debug("stats dump clock")
565                         /*r, ok := net.ticketStore.radius[testTopic]
566                         if !ok {
567                                 fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now())
568                         } else {
569                                 topics := len(net.ticketStore.tickets)
570                                 tickets := len(net.ticketStore.nodes)
571                                 rad := r.radius / (maxRadius/10000+1)
572                                 fmt.Printf("(%x) topics:%d radius:%d tickets:%d @ %v\n", net.tab.self.ID[:8], topics, rad, tickets, time.Now())
573                         }*/
574
575                         tm := Now()
576                         for topic, r := range net.ticketStore.radius {
577                                 if printTestImgLogs {
578                                         rad := r.radius / (maxRadius/1000000 + 1)
579                                         minrad := r.minRadius / (maxRadius/1000000 + 1)
580                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*R %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], rad)
581                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*MR %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], minrad)
582                                 }
583                         }
584                         for topic, t := range net.topictab.topics {
585                                 wp := t.wcl.nextWaitPeriod(tm)
586                                 if printTestImgLogs {
587                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*W %d %v %016x %d\n", tm/1000000, topic, net.tab.self.sha[:8], wp/1000000)
588                                 }
589                         }
590
591                 // Periodic / lookup-initiated bucket refresh.
592                 case <-refreshTimer.C:
593                         log.WithFields(log.Fields{"module": logModule}).Debug("refresh timer clock")
594                         // TODO: ideally we would start the refresh timer after
595                         // fallback nodes have been set for the first time.
596                         if refreshDone == nil {
597                                 refreshDone = make(chan struct{})
598                                 net.refresh(refreshDone)
599                         }
600                 case <-bucketRefreshTimer.C:
601                         target := net.tab.chooseBucketRefreshTarget()
602                         go func() {
603                                 net.lookup(target, false)
604                                 bucketRefreshTimer.Reset(bucketRefreshInterval)
605                         }()
606                 case newNursery := <-net.refreshReq:
607                         log.WithFields(log.Fields{"module": logModule}).Debug("net refresh request")
608                         if newNursery != nil {
609                                 net.nursery = newNursery
610                         }
611                         if refreshDone == nil {
612                                 refreshDone = make(chan struct{})
613                                 net.refresh(refreshDone)
614                         }
615                         net.refreshResp <- refreshDone
616                 case <-refreshDone:
617                         log.WithFields(log.Fields{"module": logModule, "table size": net.tab.count}).Debug("net refresh done")
618                         if net.tab.count != 0 {
619                                 refreshDone = nil
620                                 list := searchReqWhenRefreshDone
621                                 searchReqWhenRefreshDone = nil
622                                 go func() {
623                                         for _, req := range list {
624                                                 net.topicSearchReq <- req
625                                         }
626                                 }()
627                         } else {
628                                 refreshDone = make(chan struct{})
629                                 net.refresh(refreshDone)
630                         }
631                 }
632         }
633         log.WithFields(log.Fields{"module": logModule}).Debug("loop stopped,shutting down")
634         if net.conn != nil {
635                 net.conn.Close()
636         }
637         if refreshDone != nil {
638                 // TODO: wait for pending refresh.
639                 //<-refreshResults
640         }
641         // Cancel all pending timeouts.
642         for _, timer := range net.timeoutTimers {
643                 timer.Stop()
644         }
645         if net.db != nil {
646                 net.db.close()
647         }
648         close(net.closed)
649 }
650
651 // Everything below runs on the Network.loop goroutine
652 // and can modify Node, Table and Network at any time without locking.
653
654 func (net *Network) refresh(done chan<- struct{}) {
655         var seeds []*Node
656         if net.db != nil {
657                 seeds = net.db.querySeeds(seedCount, seedMaxAge)
658         }
659         if len(seeds) == 0 {
660                 seeds = net.nursery
661         }
662         if len(seeds) == 0 {
663                 log.WithFields(log.Fields{"module": logModule}).Debug("no seed nodes found")
664                 time.AfterFunc(time.Second*10, func() { close(done) })
665                 return
666         }
667         for _, n := range seeds {
668                 n = net.internNodeFromDB(n)
669                 if n.state == unknown {
670                         net.transition(n, verifyinit)
671                 }
672                 // Force-add the seed node so Lookup does something.
673                 // It will be deleted again if verification fails.
674                 net.tab.add(n)
675         }
676         // Start self lookup to fill up the buckets.
677         go func() {
678                 net.Lookup(net.tab.self.ID)
679                 close(done)
680         }()
681 }
682
683 // Node Interning.
684
685 func (net *Network) internNode(pkt *ingressPacket) *Node {
686         if n := net.nodes[pkt.remoteID]; n != nil {
687                 n.IP = pkt.remoteAddr.IP
688                 n.UDP = uint16(pkt.remoteAddr.Port)
689                 n.TCP = uint16(pkt.remoteAddr.Port)
690                 return n
691         }
692         n := NewNode(pkt.remoteID, pkt.remoteAddr.IP, uint16(pkt.remoteAddr.Port), uint16(pkt.remoteAddr.Port))
693         n.state = unknown
694         net.nodes[pkt.remoteID] = n
695         return n
696 }
697
698 func (net *Network) internNodeFromDB(dbn *Node) *Node {
699         if n := net.nodes[dbn.ID]; n != nil {
700                 return n
701         }
702         n := NewNode(dbn.ID, dbn.IP, dbn.UDP, dbn.TCP)
703         n.state = unknown
704         net.nodes[n.ID] = n
705         return n
706 }
707
708 func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
709         if rn.ID == net.tab.self.ID {
710                 return nil, errors.New("is self")
711         }
712         if rn.UDP <= lowPort {
713                 return nil, errors.New("low port")
714         }
715         n = net.nodes[rn.ID]
716         if n == nil {
717                 // We haven't seen this node before.
718                 n, err = nodeFromRPC(sender, rn)
719                 if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
720                         return n, errors.New("not contained in netrestrict whitelist")
721                 }
722                 if err == nil {
723                         n.state = unknown
724                         net.nodes[n.ID] = n
725                 }
726                 return n, err
727         }
728         if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP {
729                 if n.state == known {
730                         // reject address change if node is known by us
731                         err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n)
732                 } else {
733                         // accept otherwise; this will be handled nicer with signed ENRs
734                         n.IP = rn.IP
735                         n.UDP = rn.UDP
736                         n.TCP = rn.TCP
737                 }
738         }
739         return n, err
740 }
741
742 // nodeNetGuts is embedded in Node and holds the per-node data used by the Network state machine: the cached ID hash, the current state, outstanding ping data and findnode query bookkeeping.
743 type nodeNetGuts struct {
744         // This is a cached copy of sha3(ID) which is used for node
745         // distance calculations. This is part of Node in order to make it
746         // possible to write tests that need a node at a certain distance.
747         // In those tests, the content of sha will not actually correspond
748         // with ID.
749         sha common.Hash
750
751         // State machine fields. Access to these fields
752         // is restricted to the Network.loop goroutine.
753         state             *nodeState // current state; Network.handle treats nil as unknown
754         pingEcho          []byte           // hash of last ping sent by us; nil when no ping is outstanding
755         pingTopics        []Topic          // topic set sent by us in last ping
756         deferredQueries   []*findnodeQuery // queries that can't be sent yet
757         pendingNeighbours *findnodeQuery   // current query, waiting for reply
758         queryTimeouts     int // failure counter; incremented in handleQueryEvent, reset on entering known or unknown
759 }
760
761 func (n *nodeNetGuts) deferQuery(q *findnodeQuery) {
762         n.deferredQueries = append(n.deferredQueries, q)
763 }
764
765 func (n *nodeNetGuts) startNextQuery(net *Network) {
766         if len(n.deferredQueries) == 0 {
767                 return
768         }
769         nextq := n.deferredQueries[0]
770         if nextq.start(net) {
771                 n.deferredQueries = append(n.deferredQueries[:0], n.deferredQueries[1:]...)
772         }
773 }
774
775 func (q *findnodeQuery) start(net *Network) bool {
776         // Satisfy queries against the local node directly.
777         if q.remote == net.tab.self {
778                 log.WithFields(log.Fields{"module": logModule}).Debug("findnodeQuery self")
779                 closest := net.tab.closest(common.BytesToHash(q.target[:]), bucketSize)
780
781                 q.reply <- closest.entries
782                 return true
783         }
784         if q.remote.state.canQuery && q.remote.pendingNeighbours == nil {
785                 log.WithFields(log.Fields{"module": logModule, "remote peer": q.remote.ID, "targetID": q.target}).Debug("find node query")
786                 net.conn.sendFindnodeHash(q.remote, q.target)
787                 net.timedEvent(respTimeout, q.remote, neighboursTimeout)
788                 q.remote.pendingNeighbours = q
789                 return true
790         }
791         // If the node is not known yet, it won't accept queries.
792         // Initiate the transition to known.
793         // The request will be sent later when the node reaches known state.
794         if q.remote.state == unknown {
795                 log.WithFields(log.Fields{"module": logModule, "id": q.remote.ID, "status": "unknown->verify init"}).Debug("find node query")
796                 net.transition(q.remote, verifyinit)
797         }
798         return false
799 }
800
801 // Node Events (the input to the state machine).
802
// nodeEvent identifies an input to the per-node state machine: either a
// received packet type or a locally generated timeout.
803 type nodeEvent uint
804
805 //go:generate stringer -type=nodeEvent
806
807 const (
808         invalidEvent nodeEvent = iota // zero is reserved
809
810         // Packet type events.
811         // These correspond to packet types in the UDP protocol.
812         pingPacket
813         pongPacket
814         findnodePacket
815         neighborsPacket
816         findnodeHashPacket
817         topicRegisterPacket
818         topicQueryPacket
819         topicNodesPacket
820
821         // Non-packet events.
822         // Event values in this category are allocated outside
823         // the packet type range (packet types are encoded as a single byte).
824         pongTimeout nodeEvent = iota + 256 // armed by Network.ping after a ping is sent
825         pingTimeout // armed when a node enters remoteverifywait
826         neighboursTimeout // armed by findnodeQuery.start after a findnode query is sent
827 )
828
829 // Node State Machine.
830
// nodeState describes one state of the per-node state machine.
831 type nodeState struct {
832         name     string // returned by String
833         handle   func(*Network, *Node, nodeEvent, *ingressPacket) (next *nodeState, err error) // processes an event; Network.handle transitions the node to next
834         enter    func(*Network, *Node) // optional; run by Network.transition when a node changes into this state
835         canQuery bool // findnodeQuery.start only sends queries to nodes whose state sets this
836 }
837
838 func (s *nodeState) String() string {
839         return s.name
840 }
841
// The states of the per-node state machine. They are populated in init.
842 var (
843         unknown          *nodeState // state given to freshly interned nodes; entering it removes the node from the table and aborts its queries
844         verifyinit       *nodeState // entering it pings the node; waits for the node's ping or pong
845         verifywait       *nodeState // waits for a pong, after which the node becomes known
846         remoteverifywait *nodeState // entering it arms pingTimeout, after which the node becomes known
847         known            *nodeState // queryable; entering it adds the node to the table
848         contested        *nodeState // queryable; entering it re-pings the node, a pong timeout makes it unresponsive
849         unresponsive     *nodeState // queryable; a ping or pong from the node makes it known again
850 )
851
852 func init() {
853         unknown = &nodeState{
854                 name: "unknown",
855                 enter: func(net *Network, n *Node) {
856                         net.tab.delete(n)
857                         n.pingEcho = nil
858                         // Abort active queries.
859                         for _, q := range n.deferredQueries {
860                                 q.reply <- nil
861                         }
862                         n.deferredQueries = nil
863                         if n.pendingNeighbours != nil {
864                                 n.pendingNeighbours.reply <- nil
865                                 n.pendingNeighbours = nil
866                         }
867                         n.queryTimeouts = 0
868                 },
869                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
870                         switch ev {
871                         case pingPacket:
872                                 net.handlePing(n, pkt)
873                                 net.ping(n, pkt.remoteAddr)
874                                 return verifywait, nil
875                         default:
876                                 return unknown, errInvalidEvent
877                         }
878                 },
879         }
880
881         verifyinit = &nodeState{
882                 name: "verifyinit",
883                 enter: func(net *Network, n *Node) {
884                         net.ping(n, n.addr())
885                 },
886                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
887                         switch ev {
888                         case pingPacket:
889                                 net.handlePing(n, pkt)
890                                 return verifywait, nil
891                         case pongPacket:
892                                 err := net.handleKnownPong(n, pkt)
893                                 return remoteverifywait, err
894                         case pongTimeout:
895                                 return unknown, nil
896                         default:
897                                 return verifyinit, errInvalidEvent
898                         }
899                 },
900         }
901
902         verifywait = &nodeState{
903                 name: "verifywait",
904                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
905                         switch ev {
906                         case pingPacket:
907                                 net.handlePing(n, pkt)
908                                 return verifywait, nil
909                         case pongPacket:
910                                 err := net.handleKnownPong(n, pkt)
911                                 return known, err
912                         case pongTimeout:
913                                 return unknown, nil
914                         default:
915                                 return verifywait, errInvalidEvent
916                         }
917                 },
918         }
919
920         remoteverifywait = &nodeState{
921                 name: "remoteverifywait",
922                 enter: func(net *Network, n *Node) {
923                         net.timedEvent(respTimeout, n, pingTimeout)
924                 },
925                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
926                         switch ev {
927                         case pingPacket:
928                                 net.handlePing(n, pkt)
929                                 return remoteverifywait, nil
930                         case pingTimeout:
931                                 return known, nil
932                         default:
933                                 return remoteverifywait, errInvalidEvent
934                         }
935                 },
936         }
937
938         known = &nodeState{
939                 name:     "known",
940                 canQuery: true,
941                 enter: func(net *Network, n *Node) {
942                         n.queryTimeouts = 0
943                         n.startNextQuery(net)
944                         // Insert into the table and start revalidation of the last node
945                         // in the bucket if it is full.
946                         last := net.tab.add(n)
947                         if last != nil && last.state == known {
948                                 // TODO: do this asynchronously
949                                 net.transition(last, contested)
950                         }
951                 },
952                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
953                         if err := net.db.updateNode(n); err != nil {
954                                 return known, err
955                         }
956
957                         switch ev {
958                         case pingPacket:
959                                 net.handlePing(n, pkt)
960                                 return known, nil
961                         case pongPacket:
962                                 err := net.handleKnownPong(n, pkt)
963                                 return known, err
964                         default:
965                                 return net.handleQueryEvent(n, ev, pkt)
966                         }
967                 },
968         }
969
970         contested = &nodeState{
971                 name:     "contested",
972                 canQuery: true,
973                 enter: func(net *Network, n *Node) {
974                         n.pingEcho = nil
975                         net.ping(n, n.addr())
976                 },
977                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
978                         switch ev {
979                         case pongPacket:
980                                 // Node is still alive.
981                                 err := net.handleKnownPong(n, pkt)
982                                 return known, err
983                         case pongTimeout:
984                                 net.tab.deleteReplace(n)
985                                 return unresponsive, nil
986                         case pingPacket:
987                                 net.handlePing(n, pkt)
988                                 return contested, nil
989                         default:
990                                 return net.handleQueryEvent(n, ev, pkt)
991                         }
992                 },
993         }
994
995         unresponsive = &nodeState{
996                 name:     "unresponsive",
997                 canQuery: true,
998                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
999                         net.db.deleteNode(n.ID)
1000
1001                         switch ev {
1002                         case pingPacket:
1003                                 net.handlePing(n, pkt)
1004                                 return known, nil
1005                         case pongPacket:
1006                                 err := net.handleKnownPong(n, pkt)
1007                                 return known, err
1008                         default:
1009                                 return net.handleQueryEvent(n, ev, pkt)
1010                         }
1011                 },
1012         }
1013 }
1014
1015 // handle processes packets sent by n and events related to n.
1016 func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error {
1017         //fmt.Println("handle", n.addr().String(), n.state, ev)
1018         if pkt != nil {
1019                 if err := net.checkPacket(n, ev, pkt); err != nil {
1020                         //fmt.Println("check err:", err)
1021                         return err
1022                 }
1023                 // Start the background expiration goroutine after the first
1024                 // successful communication. Subsequent calls have no effect if it
1025                 // is already running. We do this here instead of somewhere else
1026                 // so that the search for seed nodes also considers older nodes
1027                 // that would otherwise be removed by the expirer.
1028                 if net.db != nil {
1029                         net.db.ensureExpirer()
1030                 }
1031         }
1032         if n.state == nil {
1033                 n.state = unknown //???
1034         }
1035         next, err := n.state.handle(net, n, ev, pkt)
1036         net.transition(n, next)
1037         //fmt.Println("new state:", n.state)
1038         return err
1039 }
1040
1041 func (net *Network) checkPacket(n *Node, ev nodeEvent, pkt *ingressPacket) error {
1042         // Replay prevention checks.
1043         switch ev {
1044         case pingPacket, findnodeHashPacket, neighborsPacket:
1045                 // TODO: check date is > last date seen
1046                 // TODO: check ping version
1047         case pongPacket:
1048                 if !bytes.Equal(pkt.data.(*pong).ReplyTok, n.pingEcho) {
1049                         // fmt.Println("pong reply token mismatch")
1050                         return fmt.Errorf("pong reply token mismatch")
1051                 }
1052                 n.pingEcho = nil
1053         }
1054         // Address validation.
1055         // TODO: Ideally we would do the following:
1056         //  - reject all packets with wrong address except ping.
1057         //  - for ping with new address, transition to verifywait but keep the
1058         //    previous node (with old address) around. if the new one reaches known,
1059         //    swap it out.
1060         return nil
1061 }
1062
1063 func (net *Network) transition(n *Node, next *nodeState) {
1064         if n.state != next {
1065                 n.state = next
1066                 if next.enter != nil {
1067                         next.enter(net, n)
1068                 }
1069         }
1070
1071         // TODO: persist/unpersist node
1072 }
1073
1074 func (net *Network) timedEvent(d time.Duration, n *Node, ev nodeEvent) {
1075         timeout := timeoutEvent{ev, n}
1076         net.timeoutTimers[timeout] = time.AfterFunc(d, func() {
1077                 select {
1078                 case net.timeout <- timeout:
1079                 case <-net.closed:
1080                 }
1081         })
1082 }
1083
1084 func (net *Network) abortTimedEvent(n *Node, ev nodeEvent) {
1085         timer := net.timeoutTimers[timeoutEvent{ev, n}]
1086         if timer != nil {
1087                 timer.Stop()
1088                 delete(net.timeoutTimers, timeoutEvent{ev, n})
1089         }
1090 }
1091
1092 func (net *Network) ping(n *Node, addr *net.UDPAddr) {
1093         //fmt.Println("ping", n.addr().String(), n.ID.String(), n.sha.Hex())
1094         if n.pingEcho != nil || n.ID == net.tab.self.ID {
1095                 //fmt.Println(" not sent")
1096                 return
1097         }
1098         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Pinging remote node")
1099         n.pingTopics = net.ticketStore.regTopicSet()
1100         n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics)
1101         net.timedEvent(respTimeout, n, pongTimeout)
1102 }
1103
1104 func (net *Network) handlePing(n *Node, pkt *ingressPacket) {
1105         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling remote ping")
1106         ping := pkt.data.(*ping)
1107         n.TCP = ping.From.TCP
1108         t := net.topictab.getTicket(n, ping.Topics)
1109
1110         pong := &pong{
1111                 To:         makeEndpoint(n.addr(), n.TCP), // TODO: maybe use known TCP port from DB
1112                 ReplyTok:   pkt.hash,
1113                 Expiration: uint64(time.Now().Add(expiration).Unix()),
1114         }
1115         ticketToPong(t, pong)
1116         net.conn.send(n, pongPacket, pong)
1117 }
1118
1119 func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error {
1120         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling known pong")
1121         net.abortTimedEvent(n, pongTimeout)
1122         now := Now()
1123         ticket, err := pongToTicket(now, n.pingTopics, n, pkt)
1124         if err == nil {
1125                 // fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data)
1126                 net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket)
1127         } else {
1128                 log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Failed to convert pong to ticket")
1129         }
1130         n.pingEcho = nil
1131         n.pingTopics = nil
1132         net.db.updateLastPong(n.ID, time.Now())
1133         return err
1134 }
1135
1136 func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
1137         switch ev {
1138         case findnodePacket:
1139                 target := common.BytesToHash(pkt.data.(*findnode).Target[:])
1140                 results := net.tab.closest(target, bucketSize).entries
1141                 net.conn.sendNeighbours(n, results)
1142                 return n.state, nil
1143         case neighborsPacket:
1144                 err := net.handleNeighboursPacket(n, pkt)
1145                 return n.state, err
1146         case neighboursTimeout:
1147                 if n.pendingNeighbours != nil {
1148                         n.pendingNeighbours.reply <- nil
1149                         n.pendingNeighbours = nil
1150                 }
1151                 n.queryTimeouts++
1152                 if n.queryTimeouts > maxFindnodeFailures && n.state == known {
1153                         return contested, errors.New("too many timeouts")
1154                 }
1155                 return n.state, nil
1156
1157         // v5
1158
1159         case findnodeHashPacket:
1160                 results := net.tab.closest(pkt.data.(*findnodeHash).Target, bucketSize).entries
1161                 net.conn.sendNeighbours(n, results)
1162                 return n.state, nil
1163         case topicRegisterPacket:
1164                 //fmt.Println("got topicRegisterPacket")
1165                 regdata := pkt.data.(*topicRegister)
1166                 pong, err := net.checkTopicRegister(regdata, net.conn.getNetID())
1167                 if err != nil {
1168                         //fmt.Println(err)
1169                         return n.state, fmt.Errorf("bad waiting ticket: %v", err)
1170                 }
1171                 net.topictab.useTicket(n, pong.TicketSerial, regdata.Topics, int(regdata.Idx), pong.Expiration, pong.WaitPeriods)
1172                 return n.state, nil
1173         case topicQueryPacket:
1174                 // TODO: handle expiration
1175                 topic := pkt.data.(*topicQuery).Topic
1176                 results := net.topictab.getEntries(topic)
1177                 if _, ok := net.ticketStore.tickets[topic]; ok {
1178                         results = append(results, net.tab.self) // we're not registering in our own table but if we're advertising, return ourselves too
1179                 }
1180                 if len(results) > 10 {
1181                         results = results[:10]
1182                 }
1183                 var hash common.Hash
1184                 copy(hash[:], pkt.hash)
1185                 net.conn.sendTopicNodes(n, hash, results)
1186                 return n.state, nil
1187         case topicNodesPacket:
1188                 p := pkt.data.(*topicNodes)
1189                 if net.ticketStore.gotTopicNodes(n, p.Echo, p.Nodes) {
1190                         n.queryTimeouts++
1191                         if n.queryTimeouts > maxFindnodeFailures && n.state == known {
1192                                 return contested, errors.New("too many timeouts")
1193                         }
1194                 }
1195                 return n.state, nil
1196
1197         default:
1198                 return n.state, errInvalidEvent
1199         }
1200 }
1201
1202 func (net *Network) checkTopicRegister(data *topicRegister, netID uint64) (*pong, error) {
1203         var pongpkt ingressPacket
1204         if err := decodePacket(data.Pong, &pongpkt, netID); err != nil {
1205                 return nil, err
1206         }
1207         if pongpkt.ev != pongPacket {
1208                 return nil, errors.New("is not pong packet")
1209         }
1210         if pongpkt.remoteID != net.tab.self.ID {
1211                 return nil, errors.New("not signed by us")
1212         }
1213         // check that we previously authorised all topics
1214         // that the other side is trying to register.
1215         hash, _, _ := wireHash(data.Topics)
1216         if hash != pongpkt.data.(*pong).TopicHash {
1217                 return nil, errors.New("topic hash mismatch")
1218         }
1219         if int(data.Idx) < 0 || int(data.Idx) >= len(data.Topics) {
1220                 return nil, errors.New("topic index out of range")
1221         }
1222         return pongpkt.data.(*pong), nil
1223 }
1224
1225 func wireHash(x interface{}) (h common.Hash, n int, err error) {
1226         hw := sha3.New256()
1227         wire.WriteBinary(x, hw, &n, &err)
1228         hw.Sum(h[:0])
1229         return h, n, err
1230 }
1231
1232 func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
1233         if n.pendingNeighbours == nil {
1234                 return errNoQuery
1235         }
1236         net.abortTimedEvent(n, neighboursTimeout)
1237
1238         req := pkt.data.(*neighbors)
1239         nodes := make([]*Node, len(req.Nodes))
1240         for i, rn := range req.Nodes {
1241                 nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
1242                 if err != nil {
1243                         log.WithFields(log.Fields{"module": logModule, "ip": rn.IP, "id:": n.ID[:8], "addr:": pkt.remoteAddr, "error": err}).Debug("invalid neighbour")
1244                         continue
1245                 }
1246                 nodes[i] = nn
1247                 // Start validation of query results immediately.
1248                 // This fills the table quickly.
1249                 // TODO: generates way too many packets, maybe do it via queue.
1250                 if nn.state == unknown {
1251                         net.transition(nn, verifyinit)
1252                 }
1253         }
1254         // TODO: don't ignore second packet
1255         n.pendingNeighbours.reply <- nodes
1256         n.pendingNeighbours = nil
1257         // Now that this query is done, start the next one.
1258         n.startNextQuery(net)
1259         return nil
1260 }