OSDN Git Service

c3f6a0ef739f4d69b3d7c59c69f010ccd010b0c9
[bytom/vapor.git] / p2p / discover / dht / net.go
1 package dht
2
3 import (
4         "bytes"
5         "encoding/hex"
6         "errors"
7         "fmt"
8         "net"
9         "time"
10
11         log "github.com/sirupsen/logrus"
12         "github.com/tendermint/go-wire"
13         "golang.org/x/crypto/sha3"
14
15         "github.com/vapor/common"
16         "github.com/vapor/crypto/sha3pool"
17         "github.com/vapor/p2p/netutil"
18         "github.com/vapor/p2p/signlib"
19 )
20
// errInvalidEvent reports an event that is not valid in the current state.
var errInvalidEvent = errors.New("invalid in current state")

// errNoQuery reports that there is no pending query.
var errNoQuery = errors.New("no pending query")

// errWrongAddress reports a sender address that is not known.
var errWrongAddress = errors.New("unknown sender address")
26
const (
	// autoRefreshInterval is the period of the ticker that triggers a
	// table refresh from the main loop.
	autoRefreshInterval = time.Hour
	// bucketRefreshInterval is the delay of the timer that triggers a
	// lookup towards a bucket refresh target.
	bucketRefreshInterval = time.Minute
	// seedCount and seedMaxAge are the arguments passed to the node
	// database when querying seed nodes for a refresh.
	seedCount  = 30
	seedMaxAge = 5 * 24 * time.Hour
	// lowPort is the UDP port threshold; neighbours advertising a port
	// at or below it are rejected.
	lowPort = 1024
)
34
// printTestImgLogs enables the extra "*R", "*MR" and "*W" debug lines
// that the main loop emits on every stats dump tick.
const printTestImgLogs = false
38
39 // Network manages the table and all protocol interaction.
40 type Network struct {
41         db          *nodeDB // database of known nodes
42         conn        transport
43         netrestrict *netutil.Netlist
44
45         closed           chan struct{}          // closed when loop is done
46         closeReq         chan struct{}          // 'request to close'
47         refreshReq       chan []*Node           // lookups ask for refresh on this channel
48         refreshResp      chan (<-chan struct{}) // ...and get the channel to block on from this one
49         read             chan ingressPacket     // ingress packets arrive here
50         timeout          chan timeoutEvent
51         queryReq         chan *findnodeQuery // lookups submit findnode queries on this channel
52         tableOpReq       chan func()
53         tableOpResp      chan struct{}
54         topicRegisterReq chan topicRegisterReq
55         topicSearchReq   chan topicSearchReq
56
57         // State of the main loop.
58         tab           *Table
59         topictab      *topicTable
60         ticketStore   *ticketStore
61         nursery       []*Node
62         nodes         map[NodeID]*Node // tracks active nodes with state != known
63         timeoutTimers map[timeoutEvent]*time.Timer
64
65         // Revalidation queues.
66         // Nodes put on these queues will be pinged eventually.
67         slowRevalidateQueue []*Node
68         fastRevalidateQueue []*Node
69
70         // Buffers for state transition.
71         sendBuf []*ingressPacket
72 }
73
74 // transport is implemented by the UDP transport.
75 // it is an interface so we can test without opening lots of UDP
76 // sockets and without generating a private key.
77 type transport interface {
78         sendPing(remote *Node, remoteAddr *net.UDPAddr, topics []Topic) (hash []byte)
79         sendNeighbours(remote *Node, nodes []*Node)
80         sendFindnodeHash(remote *Node, target common.Hash)
81         sendTopicRegister(remote *Node, topics []Topic, topicIdx int, pong []byte)
82         sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node)
83
84         send(remote *Node, ptype nodeEvent, p interface{}) (hash []byte)
85
86         localAddr() *net.UDPAddr
87         getNetID() uint64
88         Close()
89 }
90
91 type findnodeQuery struct {
92         remote   *Node
93         target   common.Hash
94         reply    chan<- []*Node
95         nresults int // counter for received nodes
96 }
97
98 type topicRegisterReq struct {
99         add   bool
100         topic Topic
101 }
102
103 type topicSearchReq struct {
104         topic  Topic
105         found  chan<- *Node
106         lookup chan<- bool
107         delay  time.Duration
108 }
109
110 type topicSearchResult struct {
111         target lookupInfo
112         nodes  []*Node
113 }
114
115 type timeoutEvent struct {
116         ev   nodeEvent
117         node *Node
118 }
119
120 func hash(target []byte) common.Hash {
121         var h [32]byte
122         sha3pool.Sum256(h[:], target)
123         return common.BytesToHash(h[:])
124 }
125
126 func newNetwork(conn transport, ourPubkey signlib.PubKey, dbPath string, netrestrict *netutil.Netlist) (*Network, error) {
127         var ourID NodeID
128         copy(ourID[:], ourPubkey.Bytes()[:nodeIDBits])
129
130         var db *nodeDB
131         if dbPath != "<no database>" {
132                 var err error
133                 if db, err = newNodeDB(dbPath, Version, ourID); err != nil {
134                         return nil, err
135                 }
136         }
137
138         tab := newTable(ourID, conn.localAddr())
139         net := &Network{
140                 db:               db,
141                 conn:             conn,
142                 netrestrict:      netrestrict,
143                 tab:              tab,
144                 topictab:         newTopicTable(db, tab.self),
145                 ticketStore:      newTicketStore(),
146                 refreshReq:       make(chan []*Node),
147                 refreshResp:      make(chan (<-chan struct{})),
148                 closed:           make(chan struct{}),
149                 closeReq:         make(chan struct{}),
150                 read:             make(chan ingressPacket, 100),
151                 timeout:          make(chan timeoutEvent),
152                 timeoutTimers:    make(map[timeoutEvent]*time.Timer),
153                 tableOpReq:       make(chan func()),
154                 tableOpResp:      make(chan struct{}),
155                 queryReq:         make(chan *findnodeQuery),
156                 topicRegisterReq: make(chan topicRegisterReq),
157                 topicSearchReq:   make(chan topicSearchReq),
158                 nodes:            make(map[NodeID]*Node),
159         }
160         go net.loop()
161         return net, nil
162 }
163
164 // Close terminates the network listener and flushes the node database.
165 func (net *Network) Close() {
166         net.conn.Close()
167         select {
168         case <-net.closed:
169         case net.closeReq <- struct{}{}:
170                 <-net.closed
171         }
172 }
173
174 // Self returns the local node.
175 // The returned node should not be modified by the caller.
176 func (net *Network) Self() *Node {
177         return net.tab.self
178 }
179
180 func (net *Network) selfIP() net.IP {
181         return net.tab.self.IP
182 }
183
184 // ReadRandomNodes fills the given slice with random nodes from the
185 // table. It will not write the same node more than once. The nodes in
186 // the slice are copies and can be modified by the caller.
187 func (net *Network) ReadRandomNodes(buf []*Node) (n int) {
188         net.reqTableOp(func() { n = net.tab.readRandomNodes(buf) })
189         return n
190 }
191
192 // SetFallbackNodes sets the initial points of contact. These nodes
193 // are used to connect to the network if the table is empty and there
194 // are no known nodes in the database.
195 func (net *Network) SetFallbackNodes(nodes []*Node) error {
196         nursery := make([]*Node, 0, len(nodes))
197         for _, n := range nodes {
198                 if err := n.validateComplete(); err != nil {
199                         return fmt.Errorf("bad bootstrap/fallback node %q (%v)", n, err)
200                 }
201                 // Recompute cpy.sha because the node might not have been
202                 // created by NewNode or ParseNode.
203                 cpy := *n
204                 cpy.sha = hash(n.ID[:])
205                 nursery = append(nursery, &cpy)
206         }
207         net.reqRefresh(nursery)
208         return nil
209 }
210
211 // Resolve searches for a specific node with the given ID.
212 // It returns nil if the node could not be found.
213 func (net *Network) Resolve(targetID NodeID) *Node {
214         result := net.lookup(hash(targetID[:]), true)
215         for _, n := range result {
216                 if n.ID == targetID {
217                         return n
218                 }
219         }
220         return nil
221 }
222
223 // Lookup performs a network search for nodes close
224 // to the given target. It approaches the target by querying
225 // nodes that are closer to it on each iteration.
226 // The given target does not need to be an actual node
227 // identifier.
228 //
229 // The local node may be included in the result.
230 func (net *Network) Lookup(targetID NodeID) []*Node {
231         return net.lookup(hash(targetID[:]), false)
232 }
233
234 func (net *Network) lookup(target common.Hash, stopOnMatch bool) []*Node {
235         var (
236                 asked          = make(map[NodeID]bool)
237                 seen           = make(map[NodeID]bool)
238                 reply          = make(chan []*Node, alpha)
239                 result         = nodesByDistance{target: target}
240                 pendingQueries = 0
241         )
242         // Get initial answers from the local node.
243         result.push(net.tab.self, bucketSize)
244         for {
245                 // Ask the Î± closest nodes that we haven't asked yet.
246                 for i := 0; i < len(result.entries) && pendingQueries < alpha; i++ {
247                         n := result.entries[i]
248                         if !asked[n.ID] {
249                                 asked[n.ID] = true
250                                 pendingQueries++
251                                 net.reqQueryFindnode(n, target, reply)
252                         }
253                 }
254                 if pendingQueries == 0 {
255                         // We have asked all closest nodes, stop the search.
256                         break
257                 }
258                 // Wait for the next reply.
259                 select {
260                 case nodes := <-reply:
261                         for _, n := range nodes {
262                                 if n != nil && !seen[n.ID] {
263                                         seen[n.ID] = true
264                                         result.push(n, bucketSize)
265                                         if stopOnMatch && n.sha == target {
266                                                 return result.entries
267                                         }
268                                 }
269                         }
270                         pendingQueries--
271                 case <-time.After(respTimeout):
272                         // forget all pending requests, start new ones
273                         pendingQueries = 0
274                         reply = make(chan []*Node, alpha)
275                 }
276         }
277         return result.entries
278 }
279
280 func (net *Network) RegisterTopic(topic Topic, stop <-chan struct{}) {
281         select {
282         case net.topicRegisterReq <- topicRegisterReq{true, topic}:
283         case <-net.closed:
284                 return
285         }
286         select {
287         case <-net.closed:
288         case <-stop:
289                 select {
290                 case net.topicRegisterReq <- topicRegisterReq{false, topic}:
291                 case <-net.closed:
292                 }
293         }
294 }
295
296 func (net *Network) SearchTopic(topic Topic, setPeriod <-chan time.Duration, found chan<- *Node, lookup chan<- bool) {
297         for {
298                 select {
299                 case <-net.closed:
300                         return
301                 case delay, ok := <-setPeriod:
302                         select {
303                         case net.topicSearchReq <- topicSearchReq{topic: topic, found: found, lookup: lookup, delay: delay}:
304                         case <-net.closed:
305                                 return
306                         }
307                         if !ok {
308                                 return
309                         }
310                 }
311         }
312 }
313
314 func (net *Network) reqRefresh(nursery []*Node) <-chan struct{} {
315         select {
316         case net.refreshReq <- nursery:
317                 return <-net.refreshResp
318         case <-net.closed:
319                 return net.closed
320         }
321 }
322
323 func (net *Network) reqQueryFindnode(n *Node, target common.Hash, reply chan []*Node) bool {
324         q := &findnodeQuery{remote: n, target: target, reply: reply}
325         select {
326         case net.queryReq <- q:
327                 return true
328         case <-net.closed:
329                 return false
330         }
331 }
332
333 func (net *Network) reqReadPacket(pkt ingressPacket) {
334         select {
335         case net.read <- pkt:
336         case <-net.closed:
337         }
338 }
339
340 func (net *Network) reqTableOp(f func()) (called bool) {
341         select {
342         case net.tableOpReq <- f:
343                 <-net.tableOpResp
344                 return true
345         case <-net.closed:
346                 return false
347         }
348 }
349
350 // TODO: external address handling.
351
// topicSearchInfo is the per-topic search state kept by the main loop.
type topicSearchInfo struct {
	lookupChn chan<- bool   // receives the radius convergence flag after each search lookup
	period    time.Duration // delay before the topic is scheduled for its next search
}
356
// maxSearchCount caps the number of topic search lookups that the main
// loop keeps running at the same time.
const maxSearchCount = 5
358
359 func (net *Network) loop() {
360         var (
361                 refreshTimer       = time.NewTicker(autoRefreshInterval)
362                 bucketRefreshTimer = time.NewTimer(bucketRefreshInterval)
363                 refreshDone        chan struct{} // closed when the 'refresh' lookup has ended
364         )
365
366         // Tracking the next ticket to register.
367         var (
368                 nextTicket        *ticketRef
369                 nextRegisterTimer *time.Timer
370                 nextRegisterTime  <-chan time.Time
371         )
372         defer func() {
373                 if nextRegisterTimer != nil {
374                         nextRegisterTimer.Stop()
375                 }
376                 refreshTimer.Stop()
377                 bucketRefreshTimer.Stop()
378         }()
379         resetNextTicket := func() {
380                 ticket, timeout := net.ticketStore.nextFilteredTicket()
381                 if nextTicket != ticket {
382                         nextTicket = ticket
383                         if nextRegisterTimer != nil {
384                                 nextRegisterTimer.Stop()
385                                 nextRegisterTime = nil
386                         }
387                         if ticket != nil {
388                                 nextRegisterTimer = time.NewTimer(timeout)
389                                 nextRegisterTime = nextRegisterTimer.C
390                         }
391                 }
392         }
393
394         // Tracking registration and search lookups.
395         var (
396                 topicRegisterLookupTarget lookupInfo
397                 topicRegisterLookupDone   chan []*Node
398                 topicRegisterLookupTick   = time.NewTimer(0)
399                 searchReqWhenRefreshDone  []topicSearchReq
400                 searchInfo                = make(map[Topic]topicSearchInfo)
401                 activeSearchCount         int
402         )
403         topicSearchLookupDone := make(chan topicSearchResult, 100)
404         topicSearch := make(chan Topic, 100)
405         <-topicRegisterLookupTick.C
406
407         statsDump := time.NewTicker(10 * time.Second)
408         defer statsDump.Stop()
409
410 loop:
411         for {
412                 resetNextTicket()
413
414                 select {
415                 case <-net.closeReq:
416                         log.WithFields(log.Fields{"module": logModule}).Debug("close request")
417                         break loop
418
419                 // Ingress packet handling.
420                 case pkt := <-net.read:
421                         log.WithFields(log.Fields{"module": logModule}).Debug("read from net")
422                         n := net.internNode(&pkt)
423                         prestate := n.state
424                         status := "ok"
425                         if err := net.handle(n, pkt.ev, &pkt); err != nil {
426                                 status = err.Error()
427                         }
428                         log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": pkt.ev, "remote id": hex.EncodeToString(pkt.remoteID[:8]), "remote addr": pkt.remoteAddr, "pre state": prestate, "node state": n.state, "status": status}).Debug("handle ingress msg")
429
430                         // TODO: persist state if n.state goes >= known, delete if it goes <= known
431
432                 // State transition timeouts.
433                 case timeout := <-net.timeout:
434                         log.WithFields(log.Fields{"module": logModule}).Debug("net timeout")
435                         if net.timeoutTimers[timeout] == nil {
436                                 // Stale timer (was aborted).
437                                 continue
438                         }
439                         delete(net.timeoutTimers, timeout)
440                         prestate := timeout.node.state
441                         status := "ok"
442                         if err := net.handle(timeout.node, timeout.ev, nil); err != nil {
443                                 status = err.Error()
444                         }
445                         log.WithFields(log.Fields{"module": logModule, "node num": net.tab.count, "event": timeout.ev, "node id": hex.EncodeToString(timeout.node.ID[:8]), "node addr": timeout.node.addr(), "pre state": prestate, "node state": timeout.node.state, "status": status}).Debug("handle timeout")
446
447                 // Querying.
448                 case q := <-net.queryReq:
449                         log.WithFields(log.Fields{"module": logModule}).Debug("net query request")
450                         if !q.start(net) {
451                                 q.remote.deferQuery(q)
452                         }
453
454                 // Interacting with the table.
455                 case f := <-net.tableOpReq:
456                         log.WithFields(log.Fields{"module": logModule}).Debug("net table operate request")
457                         f()
458                         net.tableOpResp <- struct{}{}
459
460                 // Topic registration stuff.
461                 case req := <-net.topicRegisterReq:
462                         log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic register request")
463                         if !req.add {
464                                 net.ticketStore.removeRegisterTopic(req.topic)
465                                 continue
466                         }
467                         net.ticketStore.addTopic(req.topic, true)
468                         // If we're currently waiting idle (nothing to look up), give the ticket store a
469                         // chance to start it sooner. This should speed up convergence of the radius
470                         // determination for new topics.
471                         // if topicRegisterLookupDone == nil {
472                         if topicRegisterLookupTarget.target == (common.Hash{}) {
473                                 log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("topic register lookup target null")
474                                 if topicRegisterLookupTick.Stop() {
475                                         <-topicRegisterLookupTick.C
476                                 }
477                                 target, delay := net.ticketStore.nextRegisterLookup()
478                                 topicRegisterLookupTarget = target
479                                 topicRegisterLookupTick.Reset(delay)
480                         }
481
482                 case nodes := <-topicRegisterLookupDone:
483                         log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup done")
484                         net.ticketStore.registerLookupDone(topicRegisterLookupTarget, nodes, func(n *Node) []byte {
485                                 net.ping(n, n.addr())
486                                 return n.pingEcho
487                         })
488                         target, delay := net.ticketStore.nextRegisterLookup()
489                         topicRegisterLookupTarget = target
490                         topicRegisterLookupTick.Reset(delay)
491                         topicRegisterLookupDone = nil
492
493                 case <-topicRegisterLookupTick.C:
494                         log.WithFields(log.Fields{"module": logModule}).Debug("topic register lookup tick")
495                         if (topicRegisterLookupTarget.target == common.Hash{}) {
496                                 target, delay := net.ticketStore.nextRegisterLookup()
497                                 topicRegisterLookupTarget = target
498                                 topicRegisterLookupTick.Reset(delay)
499                                 topicRegisterLookupDone = nil
500                         } else {
501                                 topicRegisterLookupDone = make(chan []*Node)
502                                 target := topicRegisterLookupTarget.target
503                                 go func() { topicRegisterLookupDone <- net.lookup(target, false) }()
504                         }
505
506                 case <-nextRegisterTime:
507                         log.WithFields(log.Fields{"module": logModule}).Debug("next register time")
508                         net.ticketStore.ticketRegistered(*nextTicket)
509                         net.conn.sendTopicRegister(nextTicket.t.node, nextTicket.t.topics, nextTicket.idx, nextTicket.t.pong)
510
511                 case req := <-net.topicSearchReq:
512                         if refreshDone == nil {
513                                 log.WithFields(log.Fields{"module": logModule, "topic": req.topic}).Debug("net topic rearch req")
514                                 info, ok := searchInfo[req.topic]
515                                 if ok {
516                                         if req.delay == time.Duration(0) {
517                                                 delete(searchInfo, req.topic)
518                                                 net.ticketStore.removeSearchTopic(req.topic)
519                                         } else {
520                                                 info.period = req.delay
521                                                 searchInfo[req.topic] = info
522                                         }
523                                         continue
524                                 }
525                                 if req.delay != time.Duration(0) {
526                                         var info topicSearchInfo
527                                         info.period = req.delay
528                                         info.lookupChn = req.lookup
529                                         searchInfo[req.topic] = info
530                                         net.ticketStore.addSearchTopic(req.topic, req.found)
531                                         topicSearch <- req.topic
532                                 }
533                         } else {
534                                 searchReqWhenRefreshDone = append(searchReqWhenRefreshDone, req)
535                         }
536
537                 case topic := <-topicSearch:
538                         if activeSearchCount < maxSearchCount {
539                                 activeSearchCount++
540                                 target := net.ticketStore.nextSearchLookup(topic)
541                                 go func() {
542                                         nodes := net.lookup(target.target, false)
543                                         topicSearchLookupDone <- topicSearchResult{target: target, nodes: nodes}
544                                 }()
545                         }
546                         period := searchInfo[topic].period
547                         if period != time.Duration(0) {
548                                 go func() {
549                                         time.Sleep(period)
550                                         topicSearch <- topic
551                                 }()
552                         }
553
554                 case res := <-topicSearchLookupDone:
555                         activeSearchCount--
556                         if lookupChn := searchInfo[res.target.topic].lookupChn; lookupChn != nil {
557                                 lookupChn <- net.ticketStore.radius[res.target.topic].converged
558                         }
559                         net.ticketStore.searchLookupDone(res.target, res.nodes, func(n *Node, topic Topic) []byte {
560                                 if n.state != nil && n.state.canQuery {
561                                         return net.conn.send(n, topicQueryPacket, topicQuery{Topic: topic}) // TODO: set expiration
562                                 } else {
563                                         if n.state == unknown {
564                                                 net.ping(n, n.addr())
565                                         }
566                                         return nil
567                                 }
568                         })
569
570                 case <-statsDump.C:
571                         log.WithFields(log.Fields{"module": logModule}).Debug("stats dump clock")
572                         /*r, ok := net.ticketStore.radius[testTopic]
573                         if !ok {
574                                 fmt.Printf("(%x) no radius @ %v\n", net.tab.self.ID[:8], time.Now())
575                         } else {
576                                 topics := len(net.ticketStore.tickets)
577                                 tickets := len(net.ticketStore.nodes)
578                                 rad := r.radius / (maxRadius/10000+1)
579                                 fmt.Printf("(%x) topics:%d radius:%d tickets:%d @ %v\n", net.tab.self.ID[:8], topics, rad, tickets, time.Now())
580                         }*/
581
582                         tm := Now()
583                         for topic, r := range net.ticketStore.radius {
584                                 if printTestImgLogs {
585                                         rad := r.radius / (maxRadius/1000000 + 1)
586                                         minrad := r.minRadius / (maxRadius/1000000 + 1)
587                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*R %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], rad)
588                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*MR %d %v %016x %v\n", tm/1000000, topic, net.tab.self.sha[:8], minrad)
589                                 }
590                         }
591                         for topic, t := range net.topictab.topics {
592                                 wp := t.wcl.nextWaitPeriod(tm)
593                                 if printTestImgLogs {
594                                         log.WithFields(log.Fields{"module": logModule}).Debugf("*W %d %v %016x %d\n", tm/1000000, topic, net.tab.self.sha[:8], wp/1000000)
595                                 }
596                         }
597
598                 // Periodic / lookup-initiated bucket refresh.
599                 case <-refreshTimer.C:
600                         log.WithFields(log.Fields{"module": logModule}).Debug("refresh timer clock")
601                         // TODO: ideally we would start the refresh timer after
602                         // fallback nodes have been set for the first time.
603                         if refreshDone == nil {
604                                 refreshDone = make(chan struct{})
605                                 net.refresh(refreshDone)
606                         }
607                 case <-bucketRefreshTimer.C:
608                         target := net.tab.chooseBucketRefreshTarget()
609                         go func() {
610                                 net.lookup(target, false)
611                                 bucketRefreshTimer.Reset(bucketRefreshInterval)
612                         }()
613                 case newNursery := <-net.refreshReq:
614                         log.WithFields(log.Fields{"module": logModule}).Debug("net refresh request")
615                         if newNursery != nil {
616                                 net.nursery = newNursery
617                         }
618                         if refreshDone == nil {
619                                 refreshDone = make(chan struct{})
620                                 net.refresh(refreshDone)
621                         }
622                         net.refreshResp <- refreshDone
623                 case <-refreshDone:
624                         log.WithFields(log.Fields{"module": logModule, "table size": net.tab.count}).Debug("net refresh done")
625                         if net.tab.count != 0 {
626                                 refreshDone = nil
627                                 list := searchReqWhenRefreshDone
628                                 searchReqWhenRefreshDone = nil
629                                 go func() {
630                                         for _, req := range list {
631                                                 net.topicSearchReq <- req
632                                         }
633                                 }()
634                         } else {
635                                 refreshDone = make(chan struct{})
636                                 net.refresh(refreshDone)
637                         }
638                 }
639         }
640         log.WithFields(log.Fields{"module": logModule}).Debug("loop stopped,shutting down")
641         if net.conn != nil {
642                 net.conn.Close()
643         }
644         if refreshDone != nil {
645                 // TODO: wait for pending refresh.
646                 //<-refreshResults
647         }
648         // Cancel all pending timeouts.
649         for _, timer := range net.timeoutTimers {
650                 timer.Stop()
651         }
652         if net.db != nil {
653                 net.db.close()
654         }
655         close(net.closed)
656 }
657
658 // Everything below runs on the Network.loop goroutine
659 // and can modify Node, Table and Network at any time without locking.
660
661 func (net *Network) refresh(done chan<- struct{}) {
662         var seeds []*Node
663         if net.db != nil {
664                 seeds = net.db.querySeeds(seedCount, seedMaxAge)
665         }
666         if len(seeds) == 0 {
667                 seeds = net.nursery
668         }
669         if len(seeds) == 0 {
670                 log.WithFields(log.Fields{"module": logModule}).Debug("no seed nodes found")
671                 time.AfterFunc(time.Second*10, func() { close(done) })
672                 return
673         }
674         for _, n := range seeds {
675                 n = net.internNodeFromDB(n)
676                 if n.state == unknown {
677                         net.transition(n, verifyinit)
678                 }
679                 // Force-add the seed node so Lookup does something.
680                 // It will be deleted again if verification fails.
681                 net.tab.add(n)
682         }
683         // Start self lookup to fill up the buckets.
684         go func() {
685                 net.Lookup(net.tab.self.ID)
686                 close(done)
687         }()
688 }
689
690 // Node Interning.
691
692 func (net *Network) internNode(pkt *ingressPacket) *Node {
693         if n := net.nodes[pkt.remoteID]; n != nil {
694                 n.IP = pkt.remoteAddr.IP
695                 n.UDP = uint16(pkt.remoteAddr.Port)
696                 n.TCP = uint16(pkt.remoteAddr.Port)
697                 return n
698         }
699         n := NewNode(pkt.remoteID, pkt.remoteAddr.IP, uint16(pkt.remoteAddr.Port), uint16(pkt.remoteAddr.Port))
700         n.state = unknown
701         net.nodes[pkt.remoteID] = n
702         return n
703 }
704
705 func (net *Network) internNodeFromDB(dbn *Node) *Node {
706         if n := net.nodes[dbn.ID]; n != nil {
707                 return n
708         }
709         n := NewNode(dbn.ID, dbn.IP, dbn.UDP, dbn.TCP)
710         n.state = unknown
711         net.nodes[n.ID] = n
712         return n
713 }
714
715 func (net *Network) internNodeFromNeighbours(sender *net.UDPAddr, rn rpcNode) (n *Node, err error) {
716         if rn.ID == net.tab.self.ID {
717                 return nil, errors.New("is self")
718         }
719         if rn.UDP <= lowPort {
720                 return nil, errors.New("low port")
721         }
722         n = net.nodes[rn.ID]
723         if n == nil {
724                 // We haven't seen this node before.
725                 n, err = nodeFromRPC(sender, rn)
726                 if net.netrestrict != nil && !net.netrestrict.Contains(n.IP) {
727                         return n, errors.New("not contained in netrestrict whitelist")
728                 }
729                 if err == nil {
730                         n.state = unknown
731                         net.nodes[n.ID] = n
732                 }
733                 return n, err
734         }
735         if !n.IP.Equal(rn.IP) || n.UDP != rn.UDP || n.TCP != rn.TCP {
736                 if n.state == known {
737                         // reject address change if node is known by us
738                         err = fmt.Errorf("metadata mismatch: got %v, want %v", rn, n)
739                 } else {
740                         // accept otherwise; this will be handled nicer with signed ENRs
741                         n.IP = rn.IP
742                         n.UDP = rn.UDP
743                         n.TCP = rn.TCP
744                 }
745         }
746         return n, err
747 }
748
// nodeNetGuts is embedded in Node and contains the per-node fields used by
// the Network state machine and by node distance calculations.
type nodeNetGuts struct {
	// This is a cached copy of sha3(ID) which is used for node
	// distance calculations. This is part of Node in order to make it
	// possible to write tests that need a node at a certain distance.
	// In those tests, the content of sha will not actually correspond
	// with ID.
	sha common.Hash

	// State machine fields. Access to these fields
	// is restricted to the Network.loop goroutine.
	state             *nodeState       // current state; Network.handle treats nil as unknown
	pingEcho          []byte           // hash of last ping sent by us
	pingTopics        []Topic          // topic set sent by us in last ping
	deferredQueries   []*findnodeQuery // queries that can't be sent yet
	pendingNeighbours *findnodeQuery   // current query, waiting for reply
	queryTimeouts     int              // failed queries so far; reset on entering unknown or known
}
767
768 func (n *nodeNetGuts) deferQuery(q *findnodeQuery) {
769         n.deferredQueries = append(n.deferredQueries, q)
770 }
771
772 func (n *nodeNetGuts) startNextQuery(net *Network) {
773         if len(n.deferredQueries) == 0 {
774                 return
775         }
776         nextq := n.deferredQueries[0]
777         if nextq.start(net) {
778                 n.deferredQueries = append(n.deferredQueries[:0], n.deferredQueries[1:]...)
779         }
780 }
781
782 func (q *findnodeQuery) start(net *Network) bool {
783         // Satisfy queries against the local node directly.
784         if q.remote == net.tab.self {
785                 log.WithFields(log.Fields{"module": logModule}).Debug("findnodeQuery self")
786                 closest := net.tab.closest(q.target, bucketSize)
787
788                 q.reply <- closest.entries
789                 return true
790         }
791         if q.remote.state.canQuery && q.remote.pendingNeighbours == nil {
792                 log.WithFields(log.Fields{"module": logModule, "remote peer": q.remote.ID, "targetID": q.target}).Debug("find node query")
793                 net.conn.sendFindnodeHash(q.remote, q.target)
794                 net.timedEvent(respTimeout, q.remote, neighboursTimeout)
795                 q.remote.pendingNeighbours = q
796                 return true
797         }
798         // If the node is not known yet, it won't accept queries.
799         // Initiate the transition to known.
800         // The request will be sent later when the node reaches known state.
801         if q.remote.state == unknown {
802                 log.WithFields(log.Fields{"module": logModule, "id": q.remote.ID, "status": "unknown->verify init"}).Debug("find node query")
803                 net.transition(q.remote, verifyinit)
804         }
805         return false
806 }
807
808 // Node Events (the input to the state machine).
809
// nodeEvent identifies an input to the node state machine: either the type
// of a packet or a timeout scheduled locally via Network.timedEvent.
type nodeEvent uint

//go:generate stringer -type=nodeEvent

const (
	invalidEvent nodeEvent = iota // zero is reserved

	// Packet type events.
	// These correspond to packet types in the UDP protocol.
	pingPacket
	pongPacket
	findnodePacket
	neighborsPacket
	findnodeHashPacket
	topicRegisterPacket
	topicQueryPacket
	topicNodesPacket

	// Non-packet events.
	// Event values in this category are allocated outside
	// the packet type range (packet types are encoded as a single byte).
	// iota keeps counting across the whole block, so pongTimeout is 9 + 256.
	pongTimeout nodeEvent = iota + 256
	pingTimeout
	neighboursTimeout
)
835
836 // Node State Machine.
837
// nodeState describes one state of the per-node state machine.
type nodeState struct {
	// name is the text returned by String.
	name string
	// handle processes an event for a node in this state and returns the
	// state the node should move to next. It is invoked by Network.handle;
	// the packet is nil for events that do not carry one.
	handle func(*Network, *Node, nodeEvent, *ingressPacket) (next *nodeState, err error)
	// enter, if non-nil, is run by Network.transition when a node switches
	// into this state from a different one.
	enter func(*Network, *Node)
	// canQuery reports whether findnode queries may be sent to a node in
	// this state (checked by findnodeQuery.start).
	canQuery bool
}
844
845 func (s *nodeState) String() string {
846         return s.name
847 }
848
// The states of the node state machine. The variables are only declared
// here; their values are assigned in init.
var (
	unknown          *nodeState
	verifyinit       *nodeState
	verifywait       *nodeState
	remoteverifywait *nodeState
	known            *nodeState
	contested        *nodeState
	unresponsive     *nodeState
)
858
859 func init() {
860         unknown = &nodeState{
861                 name: "unknown",
862                 enter: func(net *Network, n *Node) {
863                         net.tab.delete(n)
864                         n.pingEcho = nil
865                         // Abort active queries.
866                         for _, q := range n.deferredQueries {
867                                 q.reply <- nil
868                         }
869                         n.deferredQueries = nil
870                         if n.pendingNeighbours != nil {
871                                 n.pendingNeighbours.reply <- nil
872                                 n.pendingNeighbours = nil
873                         }
874                         n.queryTimeouts = 0
875                 },
876                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
877                         switch ev {
878                         case pingPacket:
879                                 net.handlePing(n, pkt)
880                                 net.ping(n, pkt.remoteAddr)
881                                 return verifywait, nil
882                         default:
883                                 return unknown, errInvalidEvent
884                         }
885                 },
886         }
887
888         verifyinit = &nodeState{
889                 name: "verifyinit",
890                 enter: func(net *Network, n *Node) {
891                         net.ping(n, n.addr())
892                 },
893                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
894                         switch ev {
895                         case pingPacket:
896                                 net.handlePing(n, pkt)
897                                 return verifywait, nil
898                         case pongPacket:
899                                 err := net.handleKnownPong(n, pkt)
900                                 return remoteverifywait, err
901                         case pongTimeout:
902                                 return unknown, nil
903                         default:
904                                 return verifyinit, errInvalidEvent
905                         }
906                 },
907         }
908
909         verifywait = &nodeState{
910                 name: "verifywait",
911                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
912                         switch ev {
913                         case pingPacket:
914                                 net.handlePing(n, pkt)
915                                 return verifywait, nil
916                         case pongPacket:
917                                 err := net.handleKnownPong(n, pkt)
918                                 return known, err
919                         case pongTimeout:
920                                 return unknown, nil
921                         default:
922                                 return verifywait, errInvalidEvent
923                         }
924                 },
925         }
926
927         remoteverifywait = &nodeState{
928                 name: "remoteverifywait",
929                 enter: func(net *Network, n *Node) {
930                         net.timedEvent(respTimeout, n, pingTimeout)
931                 },
932                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
933                         switch ev {
934                         case pingPacket:
935                                 net.handlePing(n, pkt)
936                                 return remoteverifywait, nil
937                         case pingTimeout:
938                                 return known, nil
939                         default:
940                                 return remoteverifywait, errInvalidEvent
941                         }
942                 },
943         }
944
945         known = &nodeState{
946                 name:     "known",
947                 canQuery: true,
948                 enter: func(net *Network, n *Node) {
949                         n.queryTimeouts = 0
950                         n.startNextQuery(net)
951                         // Insert into the table and start revalidation of the last node
952                         // in the bucket if it is full.
953                         last := net.tab.add(n)
954                         if last != nil && last.state == known {
955                                 // TODO: do this asynchronously
956                                 net.transition(last, contested)
957                         }
958                 },
959                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
960                         if err := net.db.updateNode(n); err != nil {
961                                 return known, err
962                         }
963
964                         switch ev {
965                         case pingPacket:
966                                 net.handlePing(n, pkt)
967                                 return known, nil
968                         case pongPacket:
969                                 err := net.handleKnownPong(n, pkt)
970                                 return known, err
971                         default:
972                                 return net.handleQueryEvent(n, ev, pkt)
973                         }
974                 },
975         }
976
977         contested = &nodeState{
978                 name:     "contested",
979                 canQuery: true,
980                 enter: func(net *Network, n *Node) {
981                         n.pingEcho = nil
982                         net.ping(n, n.addr())
983                 },
984                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
985                         switch ev {
986                         case pongPacket:
987                                 // Node is still alive.
988                                 err := net.handleKnownPong(n, pkt)
989                                 return known, err
990                         case pongTimeout:
991                                 net.tab.deleteReplace(n)
992                                 return unresponsive, nil
993                         case pingPacket:
994                                 net.handlePing(n, pkt)
995                                 return contested, nil
996                         default:
997                                 return net.handleQueryEvent(n, ev, pkt)
998                         }
999                 },
1000         }
1001
1002         unresponsive = &nodeState{
1003                 name:     "unresponsive",
1004                 canQuery: true,
1005                 handle: func(net *Network, n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
1006                         net.db.deleteNode(n.ID)
1007
1008                         switch ev {
1009                         case pingPacket:
1010                                 net.handlePing(n, pkt)
1011                                 return known, nil
1012                         case pongPacket:
1013                                 err := net.handleKnownPong(n, pkt)
1014                                 return known, err
1015                         default:
1016                                 return net.handleQueryEvent(n, ev, pkt)
1017                         }
1018                 },
1019         }
1020 }
1021
1022 // handle processes packets sent by n and events related to n.
1023 func (net *Network) handle(n *Node, ev nodeEvent, pkt *ingressPacket) error {
1024         //fmt.Println("handle", n.addr().String(), n.state, ev)
1025         if pkt != nil {
1026                 if err := net.checkPacket(n, ev, pkt); err != nil {
1027                         //fmt.Println("check err:", err)
1028                         return err
1029                 }
1030                 // Start the background expiration goroutine after the first
1031                 // successful communication. Subsequent calls have no effect if it
1032                 // is already running. We do this here instead of somewhere else
1033                 // so that the search for seed nodes also considers older nodes
1034                 // that would otherwise be removed by the expirer.
1035                 if net.db != nil {
1036                         net.db.ensureExpirer()
1037                 }
1038         }
1039         if n.state == nil {
1040                 n.state = unknown //???
1041         }
1042         next, err := n.state.handle(net, n, ev, pkt)
1043         net.transition(n, next)
1044         //fmt.Println("new state:", n.state)
1045         return err
1046 }
1047
1048 func (net *Network) checkPacket(n *Node, ev nodeEvent, pkt *ingressPacket) error {
1049         // Replay prevention checks.
1050         switch ev {
1051         case pingPacket, findnodeHashPacket, neighborsPacket:
1052                 // TODO: check date is > last date seen
1053                 // TODO: check ping version
1054         case pongPacket:
1055                 if !bytes.Equal(pkt.data.(*pong).ReplyTok, n.pingEcho) {
1056                         // fmt.Println("pong reply token mismatch")
1057                         return fmt.Errorf("pong reply token mismatch")
1058                 }
1059                 n.pingEcho = nil
1060         }
1061         // Address validation.
1062         // TODO: Ideally we would do the following:
1063         //  - reject all packets with wrong address except ping.
1064         //  - for ping with new address, transition to verifywait but keep the
1065         //    previous node (with old address) around. if the new one reaches known,
1066         //    swap it out.
1067         return nil
1068 }
1069
1070 func (net *Network) transition(n *Node, next *nodeState) {
1071         if n.state != next {
1072                 n.state = next
1073                 if next.enter != nil {
1074                         next.enter(net, n)
1075                 }
1076         }
1077
1078         // TODO: persist/unpersist node
1079 }
1080
1081 func (net *Network) timedEvent(d time.Duration, n *Node, ev nodeEvent) {
1082         timeout := timeoutEvent{ev, n}
1083         net.timeoutTimers[timeout] = time.AfterFunc(d, func() {
1084                 select {
1085                 case net.timeout <- timeout:
1086                 case <-net.closed:
1087                 }
1088         })
1089 }
1090
1091 func (net *Network) abortTimedEvent(n *Node, ev nodeEvent) {
1092         timer := net.timeoutTimers[timeoutEvent{ev, n}]
1093         if timer != nil {
1094                 timer.Stop()
1095                 delete(net.timeoutTimers, timeoutEvent{ev, n})
1096         }
1097 }
1098
1099 func (net *Network) ping(n *Node, addr *net.UDPAddr) {
1100         //fmt.Println("ping", n.addr().String(), n.ID.String(), n.sha.Hex())
1101         if n.pingEcho != nil || n.ID == net.tab.self.ID {
1102                 //fmt.Println(" not sent")
1103                 return
1104         }
1105         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Pinging remote node")
1106         n.pingTopics = net.ticketStore.regTopicSet()
1107         n.pingEcho = net.conn.sendPing(n, addr, n.pingTopics)
1108         net.timedEvent(respTimeout, n, pongTimeout)
1109 }
1110
1111 func (net *Network) handlePing(n *Node, pkt *ingressPacket) {
1112         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling remote ping")
1113         ping := pkt.data.(*ping)
1114         n.TCP = ping.From.TCP
1115         t := net.topictab.getTicket(n, ping.Topics)
1116
1117         pong := &pong{
1118                 To:         makeEndpoint(n.addr(), n.TCP), // TODO: maybe use known TCP port from DB
1119                 ReplyTok:   pkt.hash,
1120                 Expiration: uint64(time.Now().Add(expiration).Unix()),
1121         }
1122         ticketToPong(t, pong)
1123         net.conn.send(n, pongPacket, pong)
1124 }
1125
1126 func (net *Network) handleKnownPong(n *Node, pkt *ingressPacket) error {
1127         log.WithFields(log.Fields{"module": logModule, "node": n.ID}).Debug("Handling known pong")
1128         net.abortTimedEvent(n, pongTimeout)
1129         now := Now()
1130         ticket, err := pongToTicket(now, n.pingTopics, n, pkt)
1131         if err == nil {
1132                 // fmt.Printf("(%x) ticket: %+v\n", net.tab.self.ID[:8], pkt.data)
1133                 net.ticketStore.addTicket(now, pkt.data.(*pong).ReplyTok, ticket)
1134         } else {
1135                 log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Failed to convert pong to ticket")
1136         }
1137         n.pingEcho = nil
1138         n.pingTopics = nil
1139         net.db.updateLastPong(n.ID, time.Now())
1140         return err
1141 }
1142
1143 func (net *Network) handleQueryEvent(n *Node, ev nodeEvent, pkt *ingressPacket) (*nodeState, error) {
1144         switch ev {
1145         case findnodePacket:
1146                 results := net.tab.closest(hash(pkt.data.(*findnode).Target[:]), bucketSize).entries
1147                 net.conn.sendNeighbours(n, results)
1148                 return n.state, nil
1149         case neighborsPacket:
1150                 err := net.handleNeighboursPacket(n, pkt)
1151                 return n.state, err
1152         case neighboursTimeout:
1153                 if n.pendingNeighbours != nil {
1154                         n.pendingNeighbours.reply <- nil
1155                         n.pendingNeighbours = nil
1156                 }
1157                 n.queryTimeouts++
1158                 if n.queryTimeouts > maxFindnodeFailures && n.state == known {
1159                         return contested, errors.New("too many timeouts")
1160                 }
1161                 return n.state, nil
1162
1163         // v5
1164
1165         case findnodeHashPacket:
1166                 results := net.tab.closest(pkt.data.(*findnodeHash).Target, bucketSize).entries
1167                 net.conn.sendNeighbours(n, results)
1168                 return n.state, nil
1169         case topicRegisterPacket:
1170                 //fmt.Println("got topicRegisterPacket")
1171                 regdata := pkt.data.(*topicRegister)
1172                 pong, err := net.checkTopicRegister(regdata, net.conn.getNetID())
1173                 if err != nil {
1174                         //fmt.Println(err)
1175                         return n.state, fmt.Errorf("bad waiting ticket: %v", err)
1176                 }
1177                 net.topictab.useTicket(n, pong.TicketSerial, regdata.Topics, int(regdata.Idx), pong.Expiration, pong.WaitPeriods)
1178                 return n.state, nil
1179         case topicQueryPacket:
1180                 // TODO: handle expiration
1181                 topic := pkt.data.(*topicQuery).Topic
1182                 results := net.topictab.getEntries(topic)
1183                 if _, ok := net.ticketStore.tickets[topic]; ok {
1184                         results = append(results, net.tab.self) // we're not registering in our own table but if we're advertising, return ourselves too
1185                 }
1186                 if len(results) > 10 {
1187                         results = results[:10]
1188                 }
1189                 var hash common.Hash
1190                 copy(hash[:], pkt.hash)
1191                 net.conn.sendTopicNodes(n, hash, results)
1192                 return n.state, nil
1193         case topicNodesPacket:
1194                 p := pkt.data.(*topicNodes)
1195                 if net.ticketStore.gotTopicNodes(n, p.Echo, p.Nodes) {
1196                         n.queryTimeouts++
1197                         if n.queryTimeouts > maxFindnodeFailures && n.state == known {
1198                                 return contested, errors.New("too many timeouts")
1199                         }
1200                 }
1201                 return n.state, nil
1202
1203         default:
1204                 return n.state, errInvalidEvent
1205         }
1206 }
1207
1208 func (net *Network) checkTopicRegister(data *topicRegister, netID uint64) (*pong, error) {
1209         var pongpkt ingressPacket
1210         if err := decodePacket(data.Pong, &pongpkt, netID); err != nil {
1211                 return nil, err
1212         }
1213         if pongpkt.ev != pongPacket {
1214                 return nil, errors.New("is not pong packet")
1215         }
1216         if pongpkt.remoteID != net.tab.self.ID {
1217                 return nil, errors.New("not signed by us")
1218         }
1219         // check that we previously authorised all topics
1220         // that the other side is trying to register.
1221         hash, _, _ := wireHash(data.Topics)
1222         if hash != pongpkt.data.(*pong).TopicHash {
1223                 return nil, errors.New("topic hash mismatch")
1224         }
1225         if int(data.Idx) < 0 || int(data.Idx) >= len(data.Topics) {
1226                 return nil, errors.New("topic index out of range")
1227         }
1228         return pongpkt.data.(*pong), nil
1229 }
1230
1231 func wireHash(x interface{}) (h common.Hash, n int, err error) {
1232         hw := sha3.New256()
1233         wire.WriteBinary(x, hw, &n, &err)
1234         hw.Sum(h[:0])
1235         return h, n, err
1236 }
1237
1238 func (net *Network) handleNeighboursPacket(n *Node, pkt *ingressPacket) error {
1239         if n.pendingNeighbours == nil {
1240                 return errNoQuery
1241         }
1242         net.abortTimedEvent(n, neighboursTimeout)
1243
1244         req := pkt.data.(*neighbors)
1245         nodes := make([]*Node, len(req.Nodes))
1246         for i, rn := range req.Nodes {
1247                 nn, err := net.internNodeFromNeighbours(pkt.remoteAddr, rn)
1248                 if err != nil {
1249                         log.WithFields(log.Fields{"module": logModule, "ip": rn.IP, "id:": n.ID[:8], "addr:": pkt.remoteAddr, "error": err}).Debug("invalid neighbour")
1250                         continue
1251                 }
1252                 nodes[i] = nn
1253                 // Start validation of query results immediately.
1254                 // This fills the table quickly.
1255                 // TODO: generates way too many packets, maybe do it via queue.
1256                 if nn.state == unknown {
1257                         net.transition(nn, verifyinit)
1258                 }
1259         }
1260         // TODO: don't ignore second packet
1261         n.pendingNeighbours.reply <- nodes
1262         n.pendingNeighbours = nil
1263         // Now that this query is done, start the next one.
1264         n.startNextQuery(net)
1265         return nil
1266 }