OSDN Git Service

91c6da93db6e6ce40c6555444e64e12a1fe68171
[bytom/vapor.git] / p2p / discover / dht / node.go
1 package dht
2
3 import (
4         "crypto/ecdsa"
5         "crypto/elliptic"
6         "encoding/hex"
7         "errors"
8         "fmt"
9         "math/rand"
10         "net"
11         "net/url"
12         "regexp"
13         "strconv"
14         "strings"
15         "time"
16
17         "github.com/vapor/common"
18         "github.com/vapor/crypto"
19 )
20
21 // Node represents a host on the network.
22 // The public fields of Node may not be modified.
23 type Node struct {
24         IP       net.IP // len 4 for IPv4 or 16 for IPv6
25         UDP, TCP uint16 // port numbers
26         ID       NodeID // the node's public key
27
28         // Network-related fields are contained in nodeNetGuts.
29         // These fields are not supposed to be used off the
30         // Network.loop goroutine.
31         nodeNetGuts
32 }
33
34 // NewNode creates a new node. It is mostly meant to be used for
35 // testing purposes.
36 func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
37         if ipv4 := ip.To4(); ipv4 != nil {
38                 ip = ipv4
39         }
40         return &Node{
41                 IP:          ip,
42                 UDP:         udpPort,
43                 TCP:         tcpPort,
44                 ID:          id,
45                 nodeNetGuts: nodeNetGuts{sha: crypto.Sha256Hash(id[:])},
46         }
47 }
48
49 func (n *Node) addr() *net.UDPAddr {
50         return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
51 }
52
53 func (n *Node) setAddr(a *net.UDPAddr) {
54         n.IP = a.IP
55         if ipv4 := a.IP.To4(); ipv4 != nil {
56                 n.IP = ipv4
57         }
58         n.UDP = uint16(a.Port)
59 }
60
61 // compares the given address against the stored values.
62 func (n *Node) addrEqual(a *net.UDPAddr) bool {
63         ip := a.IP
64         if ipv4 := a.IP.To4(); ipv4 != nil {
65                 ip = ipv4
66         }
67         return n.UDP == uint16(a.Port) && n.IP.Equal(ip)
68 }
69
70 // Incomplete returns true for nodes with no IP address.
71 func (n *Node) Incomplete() bool {
72         return n.IP == nil
73 }
74
75 // checks whether n is a valid complete node.
76 func (n *Node) validateComplete() error {
77         if n.Incomplete() {
78                 return errors.New("incomplete node")
79         }
80         if n.UDP == 0 {
81                 return errors.New("missing UDP port")
82         }
83         if n.TCP == 0 {
84                 return errors.New("missing TCP port")
85         }
86         if n.IP.IsMulticast() || n.IP.IsUnspecified() {
87                 return errors.New("invalid IP (multicast/unspecified)")
88         }
89         //_, err := n.ID.Pubkey() // validate the key (on curve, etc.)
90         return nil
91 }
92
93 // The string representation of a Node is a URL.
94 // Please see ParseNode for a description of the format.
95 func (n *Node) String() string {
96         u := url.URL{Scheme: "enode"}
97         if n.Incomplete() {
98                 u.Host = fmt.Sprintf("%x", n.ID[:])
99         } else {
100                 addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
101                 u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
102                 u.Host = addr.String()
103                 if n.UDP != n.TCP {
104                         u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
105                 }
106         }
107         return u.String()
108 }
109
110 var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
111
112 // ParseNode parses a node designator.
113 //
114 // There are two basic forms of node designators
115 //   - incomplete nodes, which only have the public key (node ID)
116 //   - complete nodes, which contain the public key and IP/Port information
117 //
118 // For incomplete nodes, the designator must look like one of these
119 //
120 //    enode://<hex node id>
121 //    <hex node id>
122 //
123 // For complete nodes, the node ID is encoded in the username portion
124 // of the URL, separated from the host by an @ sign. The hostname can
125 // only be given as an IP address, DNS domain names are not allowed.
126 // The port in the host name section is the TCP listening port. If the
127 // TCP and UDP (discovery) ports differ, the UDP port is specified as
128 // query parameter "discport".
129 //
130 // In the following example, the node URL describes
131 // a node with IP address 10.3.58.6, TCP listening port 30303
132 // and UDP discovery port 30301.
133 //
134 //    enode://<hex node id>@10.3.58.6:30303?discport=30301
135 func ParseNode(rawurl string) (*Node, error) {
136         if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
137                 id, err := HexID(m[1])
138                 if err != nil {
139                         return nil, fmt.Errorf("invalid node ID (%v)", err)
140                 }
141                 return NewNode(id, nil, 0, 0), nil
142         }
143         return parseComplete(rawurl)
144 }
145
146 func parseComplete(rawurl string) (*Node, error) {
147         var (
148                 id               NodeID
149                 ip               net.IP
150                 tcpPort, udpPort uint64
151         )
152         u, err := url.Parse(rawurl)
153         if err != nil {
154                 return nil, err
155         }
156         if u.Scheme != "enode" {
157                 return nil, errors.New("invalid URL scheme, want \"enode\"")
158         }
159         // Parse the Node ID from the user portion.
160         if u.User == nil {
161                 return nil, errors.New("does not contain node ID")
162         }
163         if id, err = HexID(u.User.String()); err != nil {
164                 return nil, fmt.Errorf("invalid node ID (%v)", err)
165         }
166         // Parse the IP address.
167         host, port, err := net.SplitHostPort(u.Host)
168         if err != nil {
169                 return nil, fmt.Errorf("invalid host: %v", err)
170         }
171         if ip = net.ParseIP(host); ip == nil {
172                 return nil, errors.New("invalid IP address")
173         }
174         // Ensure the IP is 4 bytes long for IPv4 addresses.
175         if ipv4 := ip.To4(); ipv4 != nil {
176                 ip = ipv4
177         }
178         // Parse the port numbers.
179         if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
180                 return nil, errors.New("invalid port")
181         }
182         udpPort = tcpPort
183         qv := u.Query()
184         if qv.Get("discport") != "" {
185                 udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
186                 if err != nil {
187                         return nil, errors.New("invalid discport in query")
188                 }
189         }
190         return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
191 }
192
193 // MustParseNode parses a node URL. It panics if the URL is not valid.
194 func MustParseNode(rawurl string) *Node {
195         n, err := ParseNode(rawurl)
196         if err != nil {
197                 panic("invalid node URL: " + err.Error())
198         }
199         return n
200 }
201
202 // MarshalText implements encoding.TextMarshaler.
203 func (n *Node) MarshalText() ([]byte, error) {
204         return []byte(n.String()), nil
205 }
206
207 // UnmarshalText implements encoding.TextUnmarshaler.
208 func (n *Node) UnmarshalText(text []byte) error {
209         dec, err := ParseNode(string(text))
210         if err == nil {
211                 *n = *dec
212         }
213         return err
214 }
215
216 // type nodeQueue []*Node
217 //
218 // // pushNew adds n to the end if it is not present.
219 // func (nl *nodeList) appendNew(n *Node) {
220 //      for _, entry := range n {
221 //              if entry == n {
222 //                      return
223 //              }
224 //      }
225 //      *nq = append(*nq, n)
226 // }
227 //
228 // // popRandom removes a random node. Nodes closer to
229 // // to the head of the beginning of the have a slightly higher probability.
230 // func (nl *nodeList) popRandom() *Node {
231 //      ix := rand.Intn(len(*nq))
232 //      //TODO: probability as mentioned above.
233 //      nl.removeIndex(ix)
234 // }
235 //
236 // func (nl *nodeList) removeIndex(i int) *Node {
237 //      slice = *nl
238 //      if len(*slice) <= i {
239 //              return nil
240 //      }
241 //      *nl = append(slice[:i], slice[i+1:]...)
242 // }
243
244 const nodeIDBits = 32
245
246 // NodeID is a unique identifier for each node.
247 // The node identifier is a marshaled elliptic curve public key.
248 type NodeID [32]byte
249
250 // NodeID prints as a long hexadecimal number.
251 func (n NodeID) String() string {
252         return fmt.Sprintf("%x", n[:])
253 }
254
255 // The Go syntax representation of a NodeID is a call to HexID.
256 func (n NodeID) GoString() string {
257         return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
258 }
259
260 // TerminalString returns a shortened hex string for terminal logging.
261 func (n NodeID) TerminalString() string {
262         return hex.EncodeToString(n[:8])
263 }
264
265 // HexID converts a hex string to a NodeID.
266 // The string may be prefixed with 0x.
267 func HexID(in string) (NodeID, error) {
268         var id NodeID
269         b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
270         if err != nil {
271                 return id, err
272         } else if len(b) != len(id) {
273                 return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
274         }
275         copy(id[:], b)
276         return id, nil
277 }
278
279 // ByteID converts a []byte to a NodeID.
280 func ByteID(in []byte) NodeID {
281         var id NodeID
282         for i := range id {
283                 id[i] = in[i]
284         }
285         return id
286 }
287
288 // MustHexID converts a hex string to a NodeID.
289 // It panics if the string is not a valid NodeID.
290 func MustHexID(in string) NodeID {
291         id, err := HexID(in)
292         if err != nil {
293                 panic(err)
294         }
295         return id
296 }
297
298 // PubkeyID returns a marshaled representation of the given public key.
299 func PubkeyID(pub *ecdsa.PublicKey) NodeID {
300         var id NodeID
301         pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
302         if len(pbytes)-1 != len(id) {
303                 panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
304         }
305         copy(id[:], pbytes[1:])
306         return id
307 }
308
309 //// Pubkey returns the public key represented by the node ID.
310 ////// It returns an error if the ID is not a point on the curve.
311 //func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
312 //      p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
313 //      half := len(id) / 2
314 //      p.X.SetBytes(id[:half])
315 //      p.Y.SetBytes(id[half:])
316 //      if !p.Curve.IsOnCurve(p.X, p.Y) {
317 //              return nil, errors.New("id is invalid secp256k1 curve point")
318 //      }
319 //      return p, nil
320 //}
321
322 //func (id NodeID) mustPubkey() ecdsa.PublicKey {
323 //      pk, err := id.Pubkey()
324 //      if err != nil {
325 //              panic(err)
326 //      }
327 //      return *pk
328 //}
329
330 // recoverNodeID computes the public key used to sign the
331 // given hash from the signature.
332 //func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
333 //      pubkey, err := crypto.Ecrecover(hash, sig)
334 //      if err != nil {
335 //              return id, err
336 //      }
337 //      if len(pubkey)-1 != len(id) {
338 //              return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
339 //      }
340 //      for i := range id {
341 //              id[i] = pubkey[i+1]
342 //      }
343 //      return id, nil
344 //}
345
346 // distcmp compares the distances a->target and b->target.
347 // Returns -1 if a is closer to target, 1 if b is closer to target
348 // and 0 if they are equal.
349 func distcmp(target, a, b common.Hash) int {
350         for i := range target {
351                 da := a[i] ^ target[i]
352                 db := b[i] ^ target[i]
353                 if da > db {
354                         return 1
355                 } else if da < db {
356                         return -1
357                 }
358         }
359         return 0
360 }
361
362 // table of leading zero counts for bytes [0..255]
363 var lzcount = [256]int{
364         8, 7, 6, 6, 5, 5, 5, 5,
365         4, 4, 4, 4, 4, 4, 4, 4,
366         3, 3, 3, 3, 3, 3, 3, 3,
367         3, 3, 3, 3, 3, 3, 3, 3,
368         2, 2, 2, 2, 2, 2, 2, 2,
369         2, 2, 2, 2, 2, 2, 2, 2,
370         2, 2, 2, 2, 2, 2, 2, 2,
371         2, 2, 2, 2, 2, 2, 2, 2,
372         1, 1, 1, 1, 1, 1, 1, 1,
373         1, 1, 1, 1, 1, 1, 1, 1,
374         1, 1, 1, 1, 1, 1, 1, 1,
375         1, 1, 1, 1, 1, 1, 1, 1,
376         1, 1, 1, 1, 1, 1, 1, 1,
377         1, 1, 1, 1, 1, 1, 1, 1,
378         1, 1, 1, 1, 1, 1, 1, 1,
379         1, 1, 1, 1, 1, 1, 1, 1,
380         0, 0, 0, 0, 0, 0, 0, 0,
381         0, 0, 0, 0, 0, 0, 0, 0,
382         0, 0, 0, 0, 0, 0, 0, 0,
383         0, 0, 0, 0, 0, 0, 0, 0,
384         0, 0, 0, 0, 0, 0, 0, 0,
385         0, 0, 0, 0, 0, 0, 0, 0,
386         0, 0, 0, 0, 0, 0, 0, 0,
387         0, 0, 0, 0, 0, 0, 0, 0,
388         0, 0, 0, 0, 0, 0, 0, 0,
389         0, 0, 0, 0, 0, 0, 0, 0,
390         0, 0, 0, 0, 0, 0, 0, 0,
391         0, 0, 0, 0, 0, 0, 0, 0,
392         0, 0, 0, 0, 0, 0, 0, 0,
393         0, 0, 0, 0, 0, 0, 0, 0,
394         0, 0, 0, 0, 0, 0, 0, 0,
395         0, 0, 0, 0, 0, 0, 0, 0,
396 }
397
398 // logdist returns the logarithmic distance between a and b, log2(a ^ b).
399 func logdist(a, b common.Hash) int {
400         lz := 0
401         for i := range a {
402                 x := a[i] ^ b[i]
403                 if x == 0 {
404                         lz += 8
405                 } else {
406                         lz += lzcount[x]
407                         break
408                 }
409         }
410         return len(a)*8 - lz
411 }
412
413 // hashAtDistance returns a random hash such that logdist(a, b) == n
414 func hashAtDistance(a common.Hash, n int) (b common.Hash) {
415         if n == 0 {
416                 return a
417         }
418         // flip bit at position n, fill the rest with random bits
419         b = a
420         pos := len(a) - n/8 - 1
421         bit := byte(0x01) << (byte(n%8) - 1)
422         if bit == 0 {
423                 pos++
424                 bit = 0x80
425         }
426         b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
427         for i := pos + 1; i < len(a); i++ {
428                 b[i] = byte(rand.New(rand.NewSource(time.Now().UnixNano())).Intn(255))
429         }
430         return b
431 }