OSDN Git Service

check the switch level logic (#1026)
authorPaladz <yzhu101@uottawa.ca>
Thu, 7 Jun 2018 05:33:02 +0000 (13:33 +0800)
committerGitHub <noreply@github.com>
Thu, 7 Jun 2018 05:33:02 +0000 (13:33 +0800)
* check the switch level logic

* edit for code review

p2p/fuzz.go [deleted file]
p2p/listener.go
p2p/listener_test.go
p2p/node_info.go
p2p/peer.go
p2p/peer_set.go
p2p/peer_test.go [deleted file]
p2p/public_ip.go
p2p/switch.go
p2p/switch_test.go [deleted file]

diff --git a/p2p/fuzz.go b/p2p/fuzz.go
deleted file mode 100644 (file)
index aefac98..0000000
+++ /dev/null
@@ -1,173 +0,0 @@
-package p2p
-
-import (
-       "math/rand"
-       "net"
-       "sync"
-       "time"
-)
-
-const (
-       // FuzzModeDrop is a mode in which we randomly drop reads/writes, connections or sleep
-       FuzzModeDrop = iota
-       // FuzzModeDelay is a mode in which we randomly sleep
-       FuzzModeDelay
-)
-
-// FuzzedConnection wraps any net.Conn and depending on the mode either delays
-// reads/writes or randomly drops reads/writes/connections.
-type FuzzedConnection struct {
-       conn net.Conn
-
-       mtx    sync.Mutex
-       start  <-chan time.Time
-       active bool
-
-       config *FuzzConnConfig
-}
-
-// FuzzConnConfig is a FuzzedConnection configuration.
-type FuzzConnConfig struct {
-       Mode         int
-       MaxDelay     time.Duration
-       ProbDropRW   float64
-       ProbDropConn float64
-       ProbSleep    float64
-}
-
-// DefaultFuzzConnConfig returns the default config.
-func DefaultFuzzConnConfig() *FuzzConnConfig {
-       return &FuzzConnConfig{
-               Mode:         FuzzModeDrop,
-               MaxDelay:     3 * time.Second,
-               ProbDropRW:   0.2,
-               ProbDropConn: 0.00,
-               ProbSleep:    0.00,
-       }
-}
-
-// FuzzConn creates a new FuzzedConnection. Fuzzing starts immediately.
-func FuzzConn(conn net.Conn) net.Conn {
-       return FuzzConnFromConfig(conn, DefaultFuzzConnConfig())
-}
-
-// FuzzConnFromConfig creates a new FuzzedConnection from a config. Fuzzing
-// starts immediately.
-func FuzzConnFromConfig(conn net.Conn, config *FuzzConnConfig) net.Conn {
-       return &FuzzedConnection{
-               conn:   conn,
-               start:  make(<-chan time.Time),
-               active: true,
-               config: config,
-       }
-}
-
-// FuzzConnAfter creates a new FuzzedConnection. Fuzzing starts when the
-// duration elapses.
-func FuzzConnAfter(conn net.Conn, d time.Duration) net.Conn {
-       return FuzzConnAfterFromConfig(conn, d, DefaultFuzzConnConfig())
-}
-
-// FuzzConnAfterFromConfig creates a new FuzzedConnection from a config.
-// Fuzzing starts when the duration elapses.
-func FuzzConnAfterFromConfig(conn net.Conn, d time.Duration, config *FuzzConnConfig) net.Conn {
-       return &FuzzedConnection{
-               conn:   conn,
-               start:  time.After(d),
-               active: false,
-               config: config,
-       }
-}
-
-// Config returns the connection's config.
-func (fc *FuzzedConnection) Config() *FuzzConnConfig {
-       return fc.config
-}
-
-// Read implements net.Conn.
-func (fc *FuzzedConnection) Read(data []byte) (n int, err error) {
-       if fc.fuzz() {
-               return 0, nil
-       }
-       return fc.conn.Read(data)
-}
-
-// Write implements net.Conn.
-func (fc *FuzzedConnection) Write(data []byte) (n int, err error) {
-       if fc.fuzz() {
-               return 0, nil
-       }
-       return fc.conn.Write(data)
-}
-
-// Close implements net.Conn.
-func (fc *FuzzedConnection) Close() error { return fc.conn.Close() }
-
-// LocalAddr implements net.Conn.
-func (fc *FuzzedConnection) LocalAddr() net.Addr { return fc.conn.LocalAddr() }
-
-// RemoteAddr implements net.Conn.
-func (fc *FuzzedConnection) RemoteAddr() net.Addr { return fc.conn.RemoteAddr() }
-
-// SetDeadline implements net.Conn.
-func (fc *FuzzedConnection) SetDeadline(t time.Time) error { return fc.conn.SetDeadline(t) }
-
-// SetReadDeadline implements net.Conn.
-func (fc *FuzzedConnection) SetReadDeadline(t time.Time) error {
-       return fc.conn.SetReadDeadline(t)
-}
-
-// SetWriteDeadline implements net.Conn.
-func (fc *FuzzedConnection) SetWriteDeadline(t time.Time) error {
-       return fc.conn.SetWriteDeadline(t)
-}
-
-func (fc *FuzzedConnection) randomDuration() time.Duration {
-       maxDelayMillis := int(fc.config.MaxDelay.Nanoseconds() / 1000)
-       return time.Millisecond * time.Duration(rand.Int()%maxDelayMillis)
-}
-
-// implements the fuzz (delay, kill conn)
-// and returns whether or not the read/write should be ignored
-func (fc *FuzzedConnection) fuzz() bool {
-       if !fc.shouldFuzz() {
-               return false
-       }
-
-       switch fc.config.Mode {
-       case FuzzModeDrop:
-               // randomly drop the r/w, drop the conn, or sleep
-               r := rand.Float64()
-               if r <= fc.config.ProbDropRW {
-                       return true
-               } else if r < fc.config.ProbDropRW+fc.config.ProbDropConn {
-                       // XXX: can't this fail because machine precision?
-                       // XXX: do we need an error?
-                       fc.Close()
-                       return true
-               } else if r < fc.config.ProbDropRW+fc.config.ProbDropConn+fc.config.ProbSleep {
-                       time.Sleep(fc.randomDuration())
-               }
-       case FuzzModeDelay:
-               // sleep a bit
-               time.Sleep(fc.randomDuration())
-       }
-       return false
-}
-
-func (fc *FuzzedConnection) shouldFuzz() bool {
-       if fc.active {
-               return true
-       }
-
-       fc.mtx.Lock()
-       defer fc.mtx.Unlock()
-
-       select {
-       case <-fc.start:
-               fc.active = true
-               return true
-       default:
-               return false
-       }
-}
index a81abbe..21ea757 100644 (file)
@@ -6,9 +6,17 @@ import (
        "strconv"
        "time"
 
-       "github.com/bytom/p2p/upnp"
        log "github.com/sirupsen/logrus"
        cmn "github.com/tendermint/tmlibs/common"
+
+       "github.com/bytom/errors"
+       "github.com/bytom/p2p/upnp"
+)
+
+const (
+       numBufferedConnections = 10
+       defaultExternalPort    = 8770
+       tryListenTimes         = 5
 )
 
 //Listener subset of the methods of DefaultListener
@@ -20,21 +28,48 @@ type Listener interface {
        Stop() bool
 }
 
-//DefaultListener Implements bytomd server Listener
-type DefaultListener struct {
-       cmn.BaseService
+//getUPNPExternalAddress UPNP external address discovery & port mapping
+func getUPNPExternalAddress(externalPort, internalPort int) (*NetAddress, error) {
+       nat, err := upnp.Discover()
+       if err != nil {
+               return nil, errors.Wrap(err, "could not perform UPNP discover")
+       }
 
-       listener    net.Listener
-       intAddr     *NetAddress
-       extAddr     *NetAddress
-       connections chan net.Conn
+       ext, err := nat.GetExternalAddress()
+       if err != nil {
+               return nil, errors.Wrap(err, "could not perform UPNP external address")
+       }
+
+       if externalPort == 0 {
+               externalPort = defaultExternalPort
+       }
+       externalPort, err = nat.AddPortMapping("tcp", externalPort, internalPort, "bytomd", 0)
+       if err != nil {
+               return nil, errors.Wrap(err, "could not add UPNP port mapping")
+       }
+       return NewNetAddressIPPort(ext, uint16(externalPort)), nil
 }
 
-const (
-       numBufferedConnections = 10
-       defaultExternalPort    = 8770
-       tryListenSeconds       = 5
-)
+func getNaiveExternalAddress(port int, settleForLocal bool) *NetAddress {
+       addrs, err := net.InterfaceAddrs()
+       if err != nil {
+               cmn.PanicCrisis(cmn.Fmt("Could not fetch interface addresses: %v", err))
+       }
+
+       for _, a := range addrs {
+               ipnet, ok := a.(*net.IPNet)
+               if !ok {
+                       continue
+               }
+               if v4 := ipnet.IP.To4(); v4 == nil || (!settleForLocal && v4[0] == 127) {
+                       continue
+               }
+               return NewNetAddressIPPort(ipnet.IP, uint16(port))
+       }
+
+       log.Info("Node may not be connected to internet. Settling for local address")
+       return getNaiveExternalAddress(port, true)
+}
 
 func splitHostPort(addr string) (host string, port int) {
        host, portStr, err := net.SplitHostPort(addr)
@@ -48,64 +83,58 @@ func splitHostPort(addr string) (host string, port int) {
        return host, port
 }
 
+//DefaultListener Implements bytomd server Listener
+type DefaultListener struct {
+       cmn.BaseService
+
+       listener    net.Listener
+       intAddr     *NetAddress
+       extAddr     *NetAddress
+       connections chan net.Conn
+}
+
 //NewDefaultListener create a default listener
 func NewDefaultListener(protocol string, lAddr string, skipUPNP bool) (Listener, bool) {
        // Local listen IP & port
        lAddrIP, lAddrPort := splitHostPort(lAddr)
 
-       // Create listener
-       var listener net.Listener
-       var err error
-       var getExtIP = false
-       var listenerStatus = false
-
-       for i := 0; i < tryListenSeconds; i++ {
+       listener, err := net.Listen(protocol, lAddr)
+       for i := 0; i < tryListenTimes && err != nil; i++ {
+               time.Sleep(time.Second * 1)
                listener, err = net.Listen(protocol, lAddr)
-               if err == nil {
-                       break
-               } else if i < tryListenSeconds-1 {
-                       time.Sleep(time.Second * 1)
-               }
        }
        if err != nil {
                cmn.PanicCrisis(err)
        }
-       // Actual listener local IP & port
-       listenerIP, listenerPort := splitHostPort(listener.Addr().String())
-       log.Info("Local listener", " ip:", listenerIP, " port:", listenerPort)
 
-       // Determine internal address...
-       var intAddr *NetAddress
-       intAddr, err = NewNetAddressString(lAddr)
+       intAddr, err := NewNetAddressString(lAddr)
        if err != nil {
                cmn.PanicCrisis(err)
        }
 
+       // Actual listener local IP & port
+       listenerIP, listenerPort := splitHostPort(listener.Addr().String())
+       log.Info("Local listener", " ip:", listenerIP, " port:", listenerPort)
+
        // Determine external address...
        var extAddr *NetAddress
-       //skipUPNP: If true, does not try getUPNPExternalAddress()
-       if !skipUPNP {
-               // If the lAddrIP is INADDR_ANY, try UPnP
-               if lAddrIP == "" || lAddrIP == "0.0.0.0" {
-                       extAddr = getUPNPExternalAddress(lAddrPort, listenerPort)
-                       if extAddr != nil {
-                               getExtIP = true
-                               listenerStatus = true
-                       }
-               }
+       var upnpMap bool
+       if !skipUPNP && (lAddrIP == "" || lAddrIP == "0.0.0.0") {
+               extAddr, err = getUPNPExternalAddress(lAddrPort, listenerPort)
+               upnpMap = err == nil
+               log.WithField("err", err).Info("get UPNP external address")
        }
+
        if extAddr == nil {
                if address := GetIP(); address.Success == true {
-                       extAddr = NewNetAddressIPPort(net.ParseIP(address.Ip), uint16(lAddrPort))
-                       getExtIP = true
+                       extAddr = NewNetAddressIPPort(net.ParseIP(address.IP), uint16(lAddrPort))
                }
        }
-       // Otherwise just use the local address...
        if extAddr == nil {
                extAddr = getNaiveExternalAddress(listenerPort, false)
        }
        if extAddr == nil {
-               cmn.PanicCrisis("Could not determine external address!")
+               cmn.PanicCrisis("could not determine external address!")
        }
 
        dl := &DefaultListener{
@@ -116,22 +145,16 @@ func NewDefaultListener(protocol string, lAddr string, skipUPNP bool) (Listener,
        }
        dl.BaseService = *cmn.NewBaseService(nil, "DefaultListener", dl)
        dl.Start() // Started upon construction
-
-       if !listenerStatus && getExtIP {
-               conn, err := net.DialTimeout("tcp", extAddr.String(), 3*time.Second)
-
-               if err != nil && conn == nil {
-                       log.Error("Could not open listen port")
-               }
-
-               if err == nil && conn != nil {
-                       log.Info("Success open listen port")
-                       listenerStatus = true
-                       conn.Close()
-               }
+       if upnpMap {
+               return dl, true
        }
 
-       return dl, listenerStatus
+       conn, err := net.DialTimeout("tcp", extAddr.String(), 3*time.Second)
+       if err != nil {
+               return dl, false
+       }
+       conn.Close()
+       return dl, true
 }
 
 //OnStart start listener
@@ -151,20 +174,16 @@ func (l *DefaultListener) OnStop() {
 func (l *DefaultListener) listenRoutine() {
        for {
                conn, err := l.listener.Accept()
-
                if !l.IsRunning() {
                        break // Go to cleanup
                }
-
                // listener wasn't stopped,
                // yet we encountered an error.
                if err != nil {
                        cmn.PanicCrisis(err)
                }
-
                l.connections <- conn
        }
-
        // Cleanup
        close(l.connections)
 }
@@ -193,56 +212,3 @@ func (l *DefaultListener) NetListener() net.Listener {
 func (l *DefaultListener) String() string {
        return fmt.Sprintf("Listener(@%v)", l.extAddr)
 }
-
-//getUPNPExternalAddress UPNP external address discovery & port mapping
-func getUPNPExternalAddress(externalPort, internalPort int) *NetAddress {
-       log.Info("Getting UPNP external address")
-       nat, err := upnp.Discover()
-       if err != nil {
-               log.Info("Could not perform UPNP discover. error:", err)
-               return nil
-       }
-
-       ext, err := nat.GetExternalAddress()
-       if err != nil {
-               log.Info("Could not perform UPNP external address. error:", err)
-               return nil
-       }
-
-       // UPnP can't seem to get the external port, so let's just be explicit.
-       if externalPort == 0 {
-               externalPort = defaultExternalPort
-       }
-
-       externalPort, err = nat.AddPortMapping("tcp", externalPort, internalPort, "bytomd", 0)
-       if err != nil {
-               log.Info("Could not add UPNP port mapping. error:", err)
-               return nil
-       }
-
-       log.Info("Got UPNP external address ", ext)
-       return NewNetAddressIPPort(ext, uint16(externalPort))
-}
-
-func getNaiveExternalAddress(port int, settleForLocal bool) *NetAddress {
-       addrs, err := net.InterfaceAddrs()
-       if err != nil {
-               cmn.PanicCrisis(cmn.Fmt("Could not fetch interface addresses: %v", err))
-       }
-
-       for _, a := range addrs {
-               ipnet, ok := a.(*net.IPNet)
-               if !ok {
-                       continue
-               }
-               v4 := ipnet.IP.To4()
-               if v4 == nil || (!settleForLocal && v4[0] == 127) {
-                       continue
-               } // loopback
-               return NewNetAddressIPPort(ipnet.IP, uint16(port))
-       }
-
-       // try again, but settle for local
-       log.Info("Node may not be connected to internet. Settling for local address")
-       return getNaiveExternalAddress(port, true)
-}
index e7c2e6d..bd2704b 100644 (file)
@@ -5,13 +5,11 @@ package p2p
 import (
        "bytes"
        "testing"
-
-       "github.com/tendermint/tmlibs/log"
 )
 
 func TestListener(t *testing.T) {
        // Create a listener
-       l, _ := NewDefaultListener("tcp", ":8001", true, log.TestingLogger())
+       l, _ := NewDefaultListener("tcp", ":8001", true)
 
        // Dial the listener
        lAddr := l.ExternalAddress()
index 25d8da5..5411635 100644 (file)
@@ -26,34 +26,25 @@ type NodeInfo struct {
 // CONTRACT: two nodes are compatible if the major version matches and network match
 // and they have at least one channel in common.
 func (info *NodeInfo) CompatibleWith(other *NodeInfo) error {
-       iMajor, iMinor, _, iErr := splitVersion(info.Version)
-       oMajor, oMinor, _, oErr := splitVersion(other.Version)
-
-       // if our own version number is not formatted right, we messed up
-       if iErr != nil {
-               return iErr
+       iMajor, iMinor, _, err := splitVersion(info.Version)
+       if err != nil {
+               return err
        }
 
-       // version number must be formatted correctly ("x.x.x")
-       if oErr != nil {
-               return oErr
+       oMajor, oMinor, _, err := splitVersion(other.Version)
+       if err != nil {
+               return err
        }
 
-       // major version must match
        if iMajor != oMajor {
                return fmt.Errorf("Peer is on a different major version. Got %v, expected %v", oMajor, iMajor)
        }
-
-       // minor version must match
        if iMinor != oMinor {
                return fmt.Errorf("Peer is on a different minor version. Got %v, expected %v", oMinor, iMinor)
        }
-
-       // nodes must be on the same network
        if info.Network != other.Network {
                return fmt.Errorf("Peer is on a different network. Got %v, expected %v", other.Network, info.Network)
        }
-
        return nil
 }
 
index d1e6331..32620c1 100644 (file)
@@ -22,104 +22,81 @@ type peerConn struct {
        conn     net.Conn // source connection
 }
 
-// Peer represent a bytom network node
-type Peer struct {
-       cmn.BaseService
-
-       // raw peerConn and the multiplex connection
-       *peerConn
-       mconn *connection.MConnection // multiplex connection
-
-       *NodeInfo
-       Key  string
-       Data *cmn.CMap // User data.
-}
-
 // PeerConfig is a Peer configuration.
 type PeerConfig struct {
-       AuthEnc bool `mapstructure:"auth_enc"` // authenticated encryption
-
-       // times are in seconds
-       HandshakeTimeout time.Duration `mapstructure:"handshake_timeout"`
-       DialTimeout      time.Duration `mapstructure:"dial_timeout"`
-
-       MConfig *connection.MConnConfig `mapstructure:"connection"`
-
-       Fuzz       bool            `mapstructure:"fuzz"` // fuzz connection (for testing)
-       FuzzConfig *FuzzConnConfig `mapstructure:"fuzz_config"`
+       HandshakeTimeout time.Duration           `mapstructure:"handshake_timeout"` // times are in seconds
+       DialTimeout      time.Duration           `mapstructure:"dial_timeout"`
+       MConfig          *connection.MConnConfig `mapstructure:"connection"`
 }
 
 // DefaultPeerConfig returns the default config.
 func DefaultPeerConfig(config *cfg.P2PConfig) *PeerConfig {
        return &PeerConfig{
-               AuthEnc:          true,
                HandshakeTimeout: time.Duration(config.HandshakeTimeout), // * time.Second,
                DialTimeout:      time.Duration(config.DialTimeout),      // * time.Second,
                MConfig:          connection.DefaultMConnConfig(),
-               Fuzz:             false,
-               FuzzConfig:       DefaultFuzzConnConfig(),
        }
 }
 
+// Peer represent a bytom network node
+type Peer struct {
+       cmn.BaseService
+       *NodeInfo
+       *peerConn
+       mconn *connection.MConnection // multiplex connection
+       Key   string
+}
+
+// OnStart implements BaseService.
+func (p *Peer) OnStart() error {
+       p.BaseService.OnStart()
+       _, err := p.mconn.Start()
+       return err
+}
+
+// OnStop implements BaseService.
+func (p *Peer) OnStop() {
+       p.BaseService.OnStop()
+       p.mconn.Stop()
+}
+
 func newPeer(pc *peerConn, nodeInfo *NodeInfo, reactorsByCh map[byte]Reactor, chDescs []*connection.ChannelDescriptor, onPeerError func(*Peer, interface{})) *Peer {
        // Key and NodeInfo are set after Handshake
        p := &Peer{
                peerConn: pc,
                NodeInfo: nodeInfo,
-
-               Data: cmn.NewCMap(),
+               Key:      nodeInfo.PubKey.KeyString(),
        }
-       p.Key = nodeInfo.PubKey.KeyString()
        p.mconn = createMConnection(pc.conn, p, reactorsByCh, chDescs, onPeerError, pc.config.MConfig)
-
        p.BaseService = *cmn.NewBaseService(nil, "Peer", p)
        return p
 }
 
-func newOutboundPeer(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*connection.ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *cfg.P2PConfig) (*peerConn, error) {
-       return newOutboundPeerConn(addr, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig(config))
-}
-
-func newOutboundPeerConn(addr *NetAddress, reactorsByCh map[byte]Reactor, chDescs []*connection.ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*peerConn, error) {
+func newOutboundPeerConn(addr *NetAddress, ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*peerConn, error) {
        conn, err := dial(addr, config)
        if err != nil {
                return nil, errors.Wrap(err, "Error dial peer")
        }
 
-       pc, err := newPeerConn(conn, true, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, config)
+       pc, err := newPeerConn(conn, true, ourNodePrivKey, config)
        if err != nil {
                conn.Close()
                return nil, err
        }
-
        return pc, nil
 }
 
-func newInboundPeerConn(conn net.Conn, reactorsByCh map[byte]Reactor, chDescs []*connection.ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *cfg.P2PConfig) (*peerConn, error) {
-       return newPeerConn(conn, false, reactorsByCh, chDescs, onPeerError, ourNodePrivKey, DefaultPeerConfig(config))
+func newInboundPeerConn(conn net.Conn, ourNodePrivKey crypto.PrivKeyEd25519, config *cfg.P2PConfig) (*peerConn, error) {
+       return newPeerConn(conn, false, ourNodePrivKey, DefaultPeerConfig(config))
 }
 
-func newPeerConn(rawConn net.Conn, outbound bool, reactorsByCh map[byte]Reactor, chDescs []*connection.ChannelDescriptor, onPeerError func(*Peer, interface{}), ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*peerConn, error) {
-       conn := rawConn
-
-       // Fuzz connection
-       if config.Fuzz {
-               // so we have time to do peer handshakes and get set up
-               conn = FuzzConnAfterFromConfig(conn, 10*time.Second, config.FuzzConfig)
-       }
-
-       // Encrypt connection
-       if config.AuthEnc {
-               conn.SetDeadline(time.Now().Add(config.HandshakeTimeout * time.Second))
-
-               var err error
-               conn, err = connection.MakeSecretConnection(conn, ourNodePrivKey)
-               if err != nil {
-                       return nil, errors.Wrap(err, "Error creating peer")
-               }
+func newPeerConn(rawConn net.Conn, outbound bool, ourNodePrivKey crypto.PrivKeyEd25519, config *PeerConfig) (*peerConn, error) {
+       rawConn.SetDeadline(time.Now().Add(config.HandshakeTimeout * time.Second))
+       conn, err := connection.MakeSecretConnection(rawConn, ourNodePrivKey)
+       if err != nil {
+               return nil, errors.Wrap(err, "Error creating peer")
        }
 
-       // Only the information we already have
        return &peerConn{
                config:   config,
                outbound: outbound,
@@ -127,11 +104,29 @@ func newPeerConn(rawConn net.Conn, outbound bool, reactorsByCh map[byte]Reactor,
        }, nil
 }
 
+// Addr returns peer's remote network address.
+func (p *Peer) Addr() net.Addr {
+       return p.conn.RemoteAddr()
+}
+
+// CanSend returns true if the send queue is not full, false otherwise.
+func (p *Peer) CanSend(chID byte) bool {
+       if !p.IsRunning() {
+               return false
+       }
+       return p.mconn.CanSend(chID)
+}
+
 // CloseConn should be used when the peer was created, but never started.
 func (pc *peerConn) CloseConn() {
        pc.conn.Close()
 }
 
+// Equals reports whenever 2 peers are actually represent the same node.
+func (p *Peer) Equals(other *Peer) bool {
+       return p.Key == other.Key
+}
+
 // HandshakeTimeout performs a handshake between a given node and the peer.
 // NOTE: blocking
 func (pc *peerConn) HandshakeTimeout(ourNodeInfo *NodeInfo, timeout time.Duration) (*NodeInfo, error) {
@@ -139,8 +134,7 @@ func (pc *peerConn) HandshakeTimeout(ourNodeInfo *NodeInfo, timeout time.Duratio
        pc.conn.SetDeadline(time.Now().Add(timeout))
 
        var peerNodeInfo = new(NodeInfo)
-       var err1 error
-       var err2 error
+       var err1, err2 error
        cmn.Parallel(
                func() {
                        var n int
@@ -164,98 +158,40 @@ func (pc *peerConn) HandshakeTimeout(ourNodeInfo *NodeInfo, timeout time.Duratio
        return peerNodeInfo, nil
 }
 
-// Addr returns peer's remote network address.
-func (p *Peer) Addr() net.Addr {
-       return p.conn.RemoteAddr()
+// IsOutbound returns true if the connection is outbound, false otherwise.
+func (p *Peer) IsOutbound() bool {
+       return p.outbound
 }
 
 // PubKey returns peer's public key.
 func (p *Peer) PubKey() crypto.PubKeyEd25519 {
-       if p.config.AuthEnc {
-               return p.conn.(*connection.SecretConnection).RemotePubKey()
-       }
-       if p.NodeInfo == nil {
-               panic("Attempt to get peer's PubKey before calling Handshake")
-       }
-       return p.PubKey()
-}
-
-// OnStart implements BaseService.
-func (p *Peer) OnStart() error {
-       p.BaseService.OnStart()
-       _, err := p.mconn.Start()
-       return err
-}
-
-// OnStop implements BaseService.
-func (p *Peer) OnStop() {
-       p.BaseService.OnStop()
-       p.mconn.Stop()
-}
-
-// Connection returns underlying MConnection.
-func (p *Peer) Connection() *connection.MConnection {
-       return p.mconn
-}
-
-// IsOutbound returns true if the connection is outbound, false otherwise.
-func (p *Peer) IsOutbound() bool {
-       return p.outbound
+       return p.conn.(*connection.SecretConnection).RemotePubKey()
 }
 
 // Send msg to the channel identified by chID byte. Returns false if the send
 // queue is full after timeout, specified by MConnection.
 func (p *Peer) Send(chID byte, msg interface{}) bool {
        if !p.IsRunning() {
-               // see Switch#Broadcast, where we fetch the list of peers and loop over
-               // them - while we're looping, one peer may be removed and stopped.
                return false
        }
        return p.mconn.Send(chID, msg)
 }
 
-// TrySend msg to the channel identified by chID byte. Immediately returns
-// false if the send queue is full.
-func (p *Peer) TrySend(chID byte, msg interface{}) bool {
-       if !p.IsRunning() {
-               return false
-       }
-       return p.mconn.TrySend(chID, msg)
-}
-
-// CanSend returns true if the send queue is not full, false otherwise.
-func (p *Peer) CanSend(chID byte) bool {
-       if !p.IsRunning() {
-               return false
-       }
-       return p.mconn.CanSend(chID)
-}
-
 // String representation.
 func (p *Peer) String() string {
        if p.outbound {
                return fmt.Sprintf("Peer{%v %v out}", p.mconn, p.Key[:12])
        }
-
        return fmt.Sprintf("Peer{%v %v in}", p.mconn, p.Key[:12])
 }
 
-// Equals reports whenever 2 peers are actually represent the same node.
-func (p *Peer) Equals(other *Peer) bool {
-       return p.Key == other.Key
-}
-
-// Get the data for a given key.
-func (p *Peer) Get(key string) interface{} {
-       return p.Data.Get(key)
-}
-
-func dial(addr *NetAddress, config *PeerConfig) (net.Conn, error) {
-       conn, err := addr.DialTimeout(config.DialTimeout * time.Second)
-       if err != nil {
-               return nil, err
+// TrySend msg to the channel identified by chID byte. Immediately returns
+// false if the send queue is full.
+func (p *Peer) TrySend(chID byte, msg interface{}) bool {
+       if !p.IsRunning() {
+               return false
        }
-       return conn, nil
+       return p.mconn.TrySend(chID, msg)
 }
 
 func createMConnection(conn net.Conn, p *Peer, reactorsByCh map[byte]Reactor, chDescs []*connection.ChannelDescriptor, onPeerError func(*Peer, interface{}), config *connection.MConnConfig) *connection.MConnection {
@@ -270,6 +206,13 @@ func createMConnection(conn net.Conn, p *Peer, reactorsByCh map[byte]Reactor, ch
        onError := func(r interface{}) {
                onPeerError(p, r)
        }
-
        return connection.NewMConnectionWithConfig(conn, chDescs, onReceive, onError, config)
 }
+
+func dial(addr *NetAddress, config *PeerConfig) (net.Conn, error) {
+       conn, err := addr.DialTimeout(config.DialTimeout * time.Second)
+       if err != nil {
+               return nil, err
+       }
+       return conn, nil
+}
index d8d9ca2..e26746b 100644 (file)
@@ -40,15 +40,24 @@ func NewPeerSet() *PeerSet {
 func (ps *PeerSet) Add(peer *Peer) error {
        ps.mtx.Lock()
        defer ps.mtx.Unlock()
+
        if ps.lookup[peer.Key] != nil {
                return ErrDuplicatePeer
        }
 
-       index := len(ps.list)
-       // Appending is safe even with other goroutines
-       // iterating over the ps.list slice.
+       ps.lookup[peer.Key] = &peerSetItem{peer, len(ps.list)}
        ps.list = append(ps.list, peer)
-       ps.lookup[peer.Key] = &peerSetItem{peer, index}
+       return nil
+}
+
+// Get looks up a peer by the provided peerKey.
+func (ps *PeerSet) Get(peerKey string) *Peer {
+       ps.mtx.Lock()
+       defer ps.mtx.Unlock()
+       item, ok := ps.lookup[peerKey]
+       if ok {
+               return item.peer
+       }
        return nil
 }
 
@@ -61,15 +70,11 @@ func (ps *PeerSet) Has(peerKey string) bool {
        return ok
 }
 
-// Get looks up a peer by the provided peerKey.
-func (ps *PeerSet) Get(peerKey string) *Peer {
+// List threadsafe list of peers.
+func (ps *PeerSet) List() []*Peer {
        ps.mtx.Lock()
        defer ps.mtx.Unlock()
-       item, ok := ps.lookup[peerKey]
-       if ok {
-               return item.peer
-       }
-       return nil
+       return ps.list
 }
 
 // Remove discards peer if the peer was previously memoized.
@@ -101,7 +106,6 @@ func (ps *PeerSet) Remove(peer *Peer) {
        lastPeerItem.index = index
        ps.list = newList
        delete(ps.lookup, peer.Key)
-
 }
 
 // Size returns the number of unique items in the peerSet.
@@ -110,10 +114,3 @@ func (ps *PeerSet) Size() int {
        defer ps.mtx.Unlock()
        return len(ps.list)
 }
-
-// List threadsafe list of peers.
-func (ps *PeerSet) List() []*Peer {
-       ps.mtx.Lock()
-       defer ps.mtx.Unlock()
-       return ps.list
-}
diff --git a/p2p/peer_test.go b/p2p/peer_test.go
deleted file mode 100644 (file)
index 709fdf5..0000000
+++ /dev/null
@@ -1,159 +0,0 @@
-// +build !network
-
-package p2p
-
-import (
-       "net"
-       "testing"
-       "time"
-
-       log "github.com/sirupsen/logrus"
-       "github.com/stretchr/testify/assert"
-       "github.com/stretchr/testify/require"
-       crypto "github.com/tendermint/go-crypto"
-
-       cfg "github.com/bytom/config"
-)
-
-func TestPeerBasic(t *testing.T) {
-       assert, require := assert.New(t), require.New(t)
-
-       // simulate remote peer
-       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig(cfg.DefaultP2PConfig())}
-       rp.Start()
-       defer rp.Stop()
-
-       p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), DefaultPeerConfig(cfg.DefaultP2PConfig()))
-       require.Nil(err)
-
-       p.Start()
-       defer p.Stop()
-
-       assert.True(p.IsRunning())
-       assert.True(p.IsOutbound())
-       assert.False(p.IsPersistent())
-       p.makePersistent()
-       assert.True(p.IsPersistent())
-       assert.Equal(rp.Addr().String(), p.Addr().String())
-       assert.Equal(rp.PubKey(), p.PubKey())
-}
-
-func TestPeerWithoutAuthEnc(t *testing.T) {
-       assert, require := assert.New(t), require.New(t)
-
-       config := DefaultPeerConfig(cfg.DefaultP2PConfig())
-       config.AuthEnc = false
-
-       // simulate remote peer
-       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config}
-       rp.Start()
-       defer rp.Stop()
-
-       p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config)
-       require.Nil(err)
-
-       p.Start()
-       defer p.Stop()
-
-       assert.True(p.IsRunning())
-}
-
-func TestPeerSend(t *testing.T) {
-       assert, require := assert.New(t), require.New(t)
-
-       config := DefaultPeerConfig(cfg.DefaultP2PConfig())
-       config.AuthEnc = false
-
-       // simulate remote peer
-       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: config}
-       rp.Start()
-       defer rp.Stop()
-
-       p, err := createOutboundPeerAndPerformHandshake(rp.Addr(), config)
-       require.Nil(err)
-
-       p.Start()
-       defer p.Stop()
-
-       assert.True(p.CanSend(0x01))
-       assert.True(p.Send(0x01, "Asylum"))
-}
-
-func createOutboundPeerAndPerformHandshake(addr *NetAddress, config *PeerConfig) (*Peer, error) {
-       chDescs := []*ChannelDescriptor{
-               &ChannelDescriptor{ID: 0x01, Priority: 1},
-       }
-       reactorsByCh := map[byte]Reactor{0x01: NewTestReactor(chDescs, true)}
-       pk := crypto.GenPrivKeyEd25519()
-       p, err := newOutboundPeerWithConfig(addr, reactorsByCh, chDescs, func(p *Peer, r interface{}) {}, pk, config)
-       if err != nil {
-               return nil, err
-       }
-       err = p.HandshakeTimeout(&NodeInfo{
-               PubKey:  pk.PubKey().Unwrap().(crypto.PubKeyEd25519),
-               Moniker: "host_peer",
-               Network: "testing",
-               Version: "123.123.123",
-       }, 1*time.Second)
-       if err != nil {
-               return nil, err
-       }
-       return p, nil
-}
-
-type remotePeer struct {
-       PrivKey crypto.PrivKeyEd25519
-       Config  *PeerConfig
-       addr    *NetAddress
-       quit    chan struct{}
-}
-
-func (p *remotePeer) Addr() *NetAddress {
-       return p.addr
-}
-
-func (p *remotePeer) PubKey() crypto.PubKeyEd25519 {
-       return p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
-}
-
-func (p *remotePeer) Start() {
-       l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
-       if e != nil {
-               log.Fatalf("net.Listen tcp :0: %+v", e)
-       }
-       p.addr = NewNetAddress(l.Addr())
-       p.quit = make(chan struct{})
-       go p.accept(l)
-}
-
-func (p *remotePeer) Stop() {
-       close(p.quit)
-}
-
-func (p *remotePeer) accept(l net.Listener) {
-       for {
-               conn, err := l.Accept()
-               if err != nil {
-                       log.Fatalf("Failed to accept conn: %+v", err)
-               }
-               peer, err := newInboundPeerWithConfig(conn, make(map[byte]Reactor), make([]*ChannelDescriptor, 0), func(p *Peer, r interface{}) {}, p.PrivKey, p.Config)
-               if err != nil {
-                       log.Fatalf("Failed to create a peer: %+v", err)
-               }
-               err = peer.HandshakeTimeout(&NodeInfo{
-                       PubKey:  p.PrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519),
-                       Moniker: "remote_peer",
-                       Network: "testing",
-                       Version: "123.123.123",
-               }, 1*time.Second)
-               if err != nil {
-                       log.Fatalf("Failed to perform handshake: %+v", err)
-               }
-               select {
-               case <-p.quit:
-                       conn.Close()
-                       return
-               default:
-               }
-       }
-}
index 466641b..c0395f8 100644 (file)
@@ -20,15 +20,17 @@ var ipCheckServices = []string{
        "http://myexternalip.com/raw",
 }
 
-type IpResult struct {
+// IPResult is the ip check response
+type IPResult struct {
        Success bool
-       Ip      string
+       IP      string
 }
 
 var timeout = time.Duration(5)
 
-func GetIP() *IpResult {
-       resultCh := make(chan *IpResult, 1)
+// GetIP return the ip of the current host
+func GetIP() *IPResult {
+       resultCh := make(chan *IPResult, 1)
        for _, s := range ipCheckServices {
                go ipAddress(s, resultCh)
        }
@@ -38,12 +40,12 @@ func GetIP() *IpResult {
                case result := <-resultCh:
                        return result
                case <-time.After(time.Second * timeout):
-                       return &IpResult{false, ""}
+                       return &IPResult{false, ""}
                }
        }
 }
 
-func ipAddress(service string, done chan<- *IpResult) {
+func ipAddress(service string, done chan<- *IPResult) {
        client := http.Client{Timeout: time.Duration(timeout * time.Second)}
        resp, err := client.Get(service)
        if err != nil {
@@ -59,7 +61,7 @@ func ipAddress(service string, done chan<- *IpResult) {
        address := strings.TrimSpace(string(data))
        if ip := net.ParseIP(address); ip != nil && ip.To4() != nil {
                select {
-               case done <- &IpResult{true, address}:
+               case done <- &IPResult{true, address}:
                        return
                default:
                        return
index 121302b..275b57a 100644 (file)
@@ -89,86 +89,13 @@ func NewSwitch(config *cfg.P2PConfig, addrBook AddrBook, trustHistoryDB dbm.DB)
        return sw
 }
 
-// AddReactor adds the given reactor to the switch.
-// NOTE: Not goroutine safe.
-func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor {
-       // Validate the reactor.
-       // No two reactors can share the same channel.
-       reactorChannels := reactor.GetChannels()
-       for _, chDesc := range reactorChannels {
-               chID := chDesc.ID
-               if sw.reactorsByCh[chID] != nil {
-                       cmn.PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor))
-               }
-               sw.chDescs = append(sw.chDescs, chDesc)
-               sw.reactorsByCh[chID] = reactor
-       }
-       sw.reactors[name] = reactor
-       reactor.SetSwitch(sw)
-       return reactor
-}
-
-// Reactors returns a map of reactors registered on the switch.
-// NOTE: Not goroutine safe.
-func (sw *Switch) Reactors() map[string]Reactor {
-       return sw.reactors
-}
-
-// Reactor returns the reactor with the given name.
-// NOTE: Not goroutine safe.
-func (sw *Switch) Reactor(name string) Reactor {
-       return sw.reactors[name]
-}
-
-// AddListener adds the given listener to the switch for listening to incoming peer connections.
-// NOTE: Not goroutine safe.
-func (sw *Switch) AddListener(l Listener) {
-       sw.listeners = append(sw.listeners, l)
-}
-
-// Listeners returns the list of listeners the switch listens on.
-// NOTE: Not goroutine safe.
-func (sw *Switch) Listeners() []Listener {
-       return sw.listeners
-}
-
-// IsListening returns true if the switch has at least one listener.
-// NOTE: Not goroutine safe.
-func (sw *Switch) IsListening() bool {
-       return len(sw.listeners) > 0
-}
-
-// SetNodeInfo sets the switch's NodeInfo for checking compatibility and handshaking with other nodes.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) {
-       sw.nodeInfo = nodeInfo
-}
-
-// NodeInfo returns the switch's NodeInfo.
-// NOTE: Not goroutine safe.
-func (sw *Switch) NodeInfo() *NodeInfo {
-       return sw.nodeInfo
-}
-
-// SetNodePrivKey sets the switch's private key for authenticated encryption.
-// NOTE: Not goroutine safe.
-func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) {
-       sw.nodePrivKey = nodePrivKey
-       if sw.nodeInfo != nil {
-               sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
-       }
-}
-
 // OnStart implements BaseService. It starts all the reactors, peers, and listeners.
 func (sw *Switch) OnStart() error {
-       // Start reactors
        for _, reactor := range sw.reactors {
-               _, err := reactor.Start()
-               if err != nil {
+               if _, err := reactor.Start(); err != nil {
                        return err
                }
        }
-       // Start listeners
        for _, listener := range sw.listeners {
                go sw.listenerRoutine(listener)
        }
@@ -177,22 +104,37 @@ func (sw *Switch) OnStart() error {
 
 // OnStop implements BaseService. It stops all listeners, peers, and reactors.
 func (sw *Switch) OnStop() {
-       // Stop listeners
        for _, listener := range sw.listeners {
                listener.Stop()
        }
        sw.listeners = nil
-       // Stop peers
+
        for _, peer := range sw.peers.List() {
                peer.Stop()
                sw.peers.Remove(peer)
        }
-       // Stop reactors
+
        for _, reactor := range sw.reactors {
                reactor.Stop()
        }
 }
 
+//AddBannedPeer add peer to blacklist
+func (sw *Switch) AddBannedPeer(peer *Peer) error {
+       sw.mtx.Lock()
+       defer sw.mtx.Unlock()
+
+       key := peer.NodeInfo.RemoteAddrHost()
+       sw.bannedPeer[key] = time.Now().Add(defaultBanDuration)
+       datajson, err := json.Marshal(sw.bannedPeer)
+       if err != nil {
+               return err
+       }
+
+       sw.db.Set([]byte(bannedPeerKey), datajson)
+       return nil
+}
+
 // AddPeer performs the P2P handshake with a peer
 // that already has a SecretConnection. If all goes well,
 // it starts the peer and adds it to the switch.
@@ -203,14 +145,12 @@ func (sw *Switch) AddPeer(pc *peerConn) error {
        if err != nil {
                return err
        }
-       // Check version, chain id
+
        if err := sw.nodeInfo.CompatibleWith(peerNodeInfo); err != nil {
                return err
        }
 
        peer := newPeer(pc, peerNodeInfo, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError)
-
-       //filter peer
        if err := sw.filterConnByPeer(peer); err != nil {
                return err
        }
@@ -221,102 +161,54 @@ func (sw *Switch) AddPeer(pc *peerConn) error {
                        return err
                }
        }
-
-       // Add the peer to .peers.
-       // We start it first so that a peer in the list is safe to Stop.
-       // It should not err since we already checked peers.Has()
-       if err := sw.peers.Add(peer); err != nil {
-               return err
-       }
-
-       log.Info("Added peer:", peer)
-       return nil
+       return sw.peers.Add(peer)
 }
 
-func (sw *Switch) startInitPeer(peer *Peer) error {
-       peer.Start() // spawn send/recv routines
-       for _, reactor := range sw.reactors {
-               if err := reactor.AddPeer(peer); err != nil {
-                       return err
+// AddReactor adds the given reactor to the switch.
+// NOTE: Not goroutine safe.
+func (sw *Switch) AddReactor(name string, reactor Reactor) Reactor {
+       // Validate the reactor.
+       // No two reactors can share the same channel.
+       for _, chDesc := range reactor.GetChannels() {
+               chID := chDesc.ID
+               if sw.reactorsByCh[chID] != nil {
+                       cmn.PanicSanity(fmt.Sprintf("Channel %X has multiple reactors %v & %v", chID, sw.reactorsByCh[chID], reactor))
                }
+               sw.chDescs = append(sw.chDescs, chDesc)
+               sw.reactorsByCh[chID] = reactor
        }
-       return nil
-}
-
-func (sw *Switch) dialSeed(addr *NetAddress) {
-       err := sw.DialPeerWithAddress(addr)
-       if err != nil {
-               log.Info("Error dialing seed:", addr.String())
-       }
-}
-
-func (sw *Switch) addrBookDelSelf() error {
-       addr, err := NewNetAddressString(sw.nodeInfo.ListenAddr)
-       if err != nil {
-               return err
-       }
-       // remove the given address from the address book if we're added it earlier
-       sw.addrBook.RemoveAddress(addr)
-       // add the given address to the address book to avoid dialing ourselves
-       // again this is our public address
-       sw.addrBook.AddOurAddress(addr)
-       return nil
-}
-
-func (sw *Switch) filterConnByIP(ip string) error {
-       if err := sw.checkBannedPeer(ip); err != nil {
-               return ErrConnectBannedPeer
-       }
-
-       if ip == sw.nodeInfo.ListenHost() {
-               sw.addrBookDelSelf()
-               return ErrConnectSelf
-       }
-
-       return nil
+       sw.reactors[name] = reactor
+       reactor.SetSwitch(sw)
+       return reactor
 }
 
-func (sw *Switch) filterConnByPeer(peer *Peer) error {
-       if err := sw.checkBannedPeer(peer.RemoteAddrHost()); err != nil {
-               return ErrConnectBannedPeer
-       }
-
-       if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) {
-               sw.addrBookDelSelf()
-               return ErrConnectSelf
-       }
-
-       // Check for duplicate peer
-       if sw.peers.Has(peer.Key) {
-               return ErrDuplicatePeer
-       }
-       return nil
+// AddListener adds the given listener to the switch for listening to incoming peer connections.
+// NOTE: Not goroutine safe.
+func (sw *Switch) AddListener(l Listener) {
+       sw.listeners = append(sw.listeners, l)
 }
 
 //DialPeerWithAddress dial node from net address
 func (sw *Switch) DialPeerWithAddress(addr *NetAddress) error {
        log.Debug("Dialing peer address:", addr)
-
+       sw.dialing.Set(addr.IP.String(), addr)
+       defer sw.dialing.Delete(addr.IP.String())
        if err := sw.filterConnByIP(addr.IP.String()); err != nil {
                return err
        }
 
-       sw.dialing.Set(addr.IP.String(), addr)
-       defer sw.dialing.Delete(addr.IP.String())
-
-       pc, err := newOutboundPeerConn(addr, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, sw.peerConfig)
+       pc, err := newOutboundPeerConn(addr, sw.nodePrivKey, sw.peerConfig)
        if err != nil {
-               log.Debug("Failed to dial peer", " address:", addr, " error:", err)
+               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on newOutboundPeerConn")
                return err
        }
 
-       err = sw.AddPeer(pc)
-       if err != nil {
-               log.Info("Failed to add peer:", addr, " err:", err)
+       if err = sw.AddPeer(pc); err != nil {
+               log.WithFields(log.Fields{"address": addr, " err": err}).Debug("DialPeer fail on switch AddPeer")
                pc.CloseConn()
                return err
        }
-       log.Info("Dialed and added peer:", addr)
+       log.Debug("DialPeer added peer:", addr)
        return nil
 }
 
@@ -325,6 +217,18 @@ func (sw *Switch) IsDialing(addr *NetAddress) bool {
        return sw.dialing.Has(addr.IP.String())
 }
 
+// IsListening returns true if the switch has at least one listener.
+// NOTE: Not goroutine safe.
+func (sw *Switch) IsListening() bool {
+       return len(sw.listeners) > 0
+}
+
+// Listeners returns the list of listeners the switch listens on.
+// NOTE: Not goroutine safe.
+func (sw *Switch) Listeners() []Listener {
+       return sw.listeners
+}
+
 // NumPeers Returns the count of outbound/inbound and outbound-dialing peers.
 func (sw *Switch) NumPeers() (outbound, inbound, dialing int) {
        peers := sw.peers.List()
@@ -339,103 +243,157 @@ func (sw *Switch) NumPeers() (outbound, inbound, dialing int) {
        return
 }
 
+// NodeInfo returns the switch's NodeInfo.
+// NOTE: Not goroutine safe.
+func (sw *Switch) NodeInfo() *NodeInfo {
+       return sw.nodeInfo
+}
+
 //Peers return switch peerset
 func (sw *Switch) Peers() *PeerSet {
        return sw.peers
 }
 
+// SetNodeInfo sets the switch's NodeInfo for checking compatibility and handshaking with other nodes.
+// NOTE: Not goroutine safe.
+func (sw *Switch) SetNodeInfo(nodeInfo *NodeInfo) {
+       sw.nodeInfo = nodeInfo
+}
+
+// SetNodePrivKey sets the switch's private key for authenticated encryption.
+// NOTE: Not goroutine safe.
+func (sw *Switch) SetNodePrivKey(nodePrivKey crypto.PrivKeyEd25519) {
+       sw.nodePrivKey = nodePrivKey
+       if sw.nodeInfo != nil {
+               sw.nodeInfo.PubKey = nodePrivKey.PubKey().Unwrap().(crypto.PubKeyEd25519)
+       }
+}
+
 // StopPeerForError disconnects from a peer due to external error.
 func (sw *Switch) StopPeerForError(peer *Peer, reason interface{}) {
-       log.Info("Stopping peer for error.", " peer:", peer, " err:", reason)
+       log.WithFields(log.Fields{"peer": peer, " err": reason}).Debug("stopping peer for error")
        sw.stopAndRemovePeer(peer, reason)
 }
 
 // StopPeerGracefully disconnect from a peer gracefully.
 func (sw *Switch) StopPeerGracefully(peer *Peer) {
-       log.Info("Stopping peer gracefully")
        sw.stopAndRemovePeer(peer, nil)
 }
 
-func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
-       for _, reactor := range sw.reactors {
-               reactor.RemovePeer(peer, reason)
-       }
-       sw.peers.Remove(peer)
-       peer.Stop()
-}
-
-func (sw *Switch) listenerRoutine(l Listener) {
-       for {
-               inConn, ok := <-l.Connections()
-               if !ok {
-                       break
-               }
-
-               // disconnect if we alrady have 2 * MaxNumPeers, we do this because we wanna address book get exchanged even if
-               // the connect is full. The pex will disconnect the peer after address exchange, the max connected peer won't
-               // be double of MaxNumPeers
-               if sw.peers.Size() >= sw.Config.MaxNumPeers*2 {
-                       inConn.Close()
-                       log.Info("Ignoring inbound connection: already have enough peers.")
-                       continue
-               }
-
-               // New inbound connection!
-               err := sw.addPeerWithConnection(inConn)
-               if err != nil {
-                       log.Info("Ignoring inbound connection: error while adding peer.", " address:", inConn.RemoteAddr().String(), " error:", err)
-                       continue
-               }
-       }
-}
-
 func (sw *Switch) addPeerWithConnection(conn net.Conn) error {
-       peerConn, err := newInboundPeerConn(conn, sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, sw.Config)
+       peerConn, err := newInboundPeerConn(conn, sw.nodePrivKey, sw.Config)
        if err != nil {
                conn.Close()
                return err
        }
+
        if err = sw.AddPeer(peerConn); err != nil {
                conn.Close()
                return err
        }
+       return nil
+}
 
+func (sw *Switch) addrBookDelSelf() error {
+       addr, err := NewNetAddressString(sw.nodeInfo.ListenAddr)
+       if err != nil {
+               return err
+       }
+
+       sw.addrBook.RemoveAddress(addr)
+       sw.addrBook.AddOurAddress(addr)
        return nil
 }
 
-//AddBannedPeer add peer to blacklist
-func (sw *Switch) AddBannedPeer(peer *Peer) error {
+func (sw *Switch) checkBannedPeer(peer string) error {
        sw.mtx.Lock()
        defer sw.mtx.Unlock()
-       key := peer.NodeInfo.RemoteAddrHost()
-       sw.bannedPeer[key] = time.Now().Add(defaultBanDuration)
-       datajson, err := json.Marshal(sw.bannedPeer)
-       if err != nil {
-               return err
+
+       if banEnd, ok := sw.bannedPeer[peer]; ok {
+               if time.Now().Before(banEnd) {
+                       return ErrConnectBannedPeer
+               }
+               sw.delBannedPeer(peer)
        }
-       sw.db.Set([]byte(bannedPeerKey), datajson)
        return nil
 }
 
 func (sw *Switch) delBannedPeer(addr string) error {
+       sw.mtx.Lock()
+       defer sw.mtx.Unlock()
+
        delete(sw.bannedPeer, addr)
        datajson, err := json.Marshal(sw.bannedPeer)
        if err != nil {
                return err
        }
+
        sw.db.Set([]byte(bannedPeerKey), datajson)
        return nil
 }
 
-func (sw *Switch) checkBannedPeer(peer string) error {
-       sw.mtx.Lock()
-       defer sw.mtx.Unlock()
+func (sw *Switch) filterConnByIP(ip string) error {
+       if ip == sw.nodeInfo.ListenHost() {
+               sw.addrBookDelSelf()
+               return ErrConnectSelf
+       }
+       return sw.checkBannedPeer(ip)
+}
 
-       if banEnd, ok := sw.bannedPeer[peer]; ok {
-               if time.Now().Before(banEnd) {
-                       return ErrConnectBannedPeer
+func (sw *Switch) filterConnByPeer(peer *Peer) error {
+       if err := sw.checkBannedPeer(peer.RemoteAddrHost()); err != nil {
+               return err
+       }
+
+       if sw.nodeInfo.PubKey.Equals(peer.PubKey().Wrap()) {
+               sw.addrBookDelSelf()
+               return ErrConnectSelf
+       }
+
+       if sw.peers.Has(peer.Key) {
+               return ErrDuplicatePeer
+       }
+       return nil
+}
+
+func (sw *Switch) listenerRoutine(l Listener) {
+       for {
+               inConn, ok := <-l.Connections()
+               if !ok {
+                       break
+               }
+
+               // disconnect if we alrady have 2 * MaxNumPeers, we do this because we wanna address book get exchanged even if
+               // the connect is full. The pex will disconnect the peer after address exchange, the max connected peer won't
+               // be double of MaxNumPeers
+               if sw.peers.Size() >= sw.Config.MaxNumPeers*2 {
+                       inConn.Close()
+                       log.Info("Ignoring inbound connection: already have enough peers.")
+                       continue
+               }
+
+               // New inbound connection!
+               if err := sw.addPeerWithConnection(inConn); err != nil {
+                       log.Info("Ignoring inbound connection: error while adding peer.", " address:", inConn.RemoteAddr().String(), " error:", err)
+                       continue
+               }
+       }
+}
+
+func (sw *Switch) startInitPeer(peer *Peer) error {
+       peer.Start() // spawn send/recv routines
+       for _, reactor := range sw.reactors {
+               if err := reactor.AddPeer(peer); err != nil {
+                       return err
                }
-               sw.delBannedPeer(peer)
        }
        return nil
 }
+
+func (sw *Switch) stopAndRemovePeer(peer *Peer, reason interface{}) {
+       for _, reactor := range sw.reactors {
+               reactor.RemovePeer(peer, reason)
+       }
+       sw.peers.Remove(peer)
+       peer.Stop()
+}
diff --git a/p2p/switch_test.go b/p2p/switch_test.go
deleted file mode 100644 (file)
index a8f41f5..0000000
+++ /dev/null
@@ -1,338 +0,0 @@
-// +build !network
-
-package p2p
-
-import (
-       "bytes"
-       "fmt"
-       "net"
-       "sync"
-       "testing"
-       "time"
-
-       "github.com/stretchr/testify/assert"
-       "github.com/stretchr/testify/require"
-       crypto "github.com/tendermint/go-crypto"
-       wire "github.com/tendermint/go-wire"
-
-       cfg "github.com/bytom/config"
-       "github.com/tendermint/tmlibs/log"
-)
-
-var (
-       config *cfg.P2PConfig
-)
-
-func init() {
-       config = cfg.DefaultP2PConfig()
-       config.PexReactor = true
-}
-
-type PeerMessage struct {
-       PeerKey string
-       Bytes   []byte
-       Counter int
-}
-
-type TestReactor struct {
-       BaseReactor
-
-       mtx          sync.Mutex
-       channels     []*ChannelDescriptor
-       peersAdded   []*Peer
-       peersRemoved []*Peer
-       logMessages  bool
-       msgsCounter  int
-       msgsReceived map[byte][]PeerMessage
-}
-
-func NewTestReactor(channels []*ChannelDescriptor, logMessages bool) *TestReactor {
-       tr := &TestReactor{
-               channels:     channels,
-               logMessages:  logMessages,
-               msgsReceived: make(map[byte][]PeerMessage),
-       }
-       tr.BaseReactor = *NewBaseReactor("TestReactor", tr)
-       tr.SetLogger(log.TestingLogger())
-       return tr
-}
-
-func (tr *TestReactor) GetChannels() []*ChannelDescriptor {
-       return tr.channels
-}
-
-func (tr *TestReactor) AddPeer(peer *Peer) error {
-       tr.mtx.Lock()
-       defer tr.mtx.Unlock()
-       tr.peersAdded = append(tr.peersAdded, peer)
-       return nil
-}
-
-func (tr *TestReactor) RemovePeer(peer *Peer, reason interface{}) {
-       tr.mtx.Lock()
-       defer tr.mtx.Unlock()
-       tr.peersRemoved = append(tr.peersRemoved, peer)
-}
-
-func (tr *TestReactor) Receive(chID byte, peer *Peer, msgBytes []byte) {
-       if tr.logMessages {
-               tr.mtx.Lock()
-               defer tr.mtx.Unlock()
-               //fmt.Printf("Received: %X, %X\n", chID, msgBytes)
-               tr.msgsReceived[chID] = append(tr.msgsReceived[chID], PeerMessage{peer.Key, msgBytes, tr.msgsCounter})
-               tr.msgsCounter++
-       }
-}
-
-func (tr *TestReactor) getMsgs(chID byte) []PeerMessage {
-       tr.mtx.Lock()
-       defer tr.mtx.Unlock()
-       return tr.msgsReceived[chID]
-}
-
-//-----------------------------------------------------------------------------
-
-// convenience method for creating two switches connected to each other.
-// XXX: note this uses net.Pipe and not a proper TCP conn
-func makeSwitchPair(t testing.TB, initSwitch func(int, *Switch) *Switch) (*Switch, *Switch) {
-       // Create two switches that will be interconnected.
-       switches := MakeConnectedSwitches(config, 2, initSwitch, Connect2Switches)
-       return switches[0], switches[1]
-}
-
-func initSwitchFunc(i int, sw *Switch) *Switch {
-       // Make two reactors of two channels each
-       sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{
-               {ID: byte(0x00), Priority: 10},
-               {ID: byte(0x01), Priority: 10},
-       }, true))
-       sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{
-               {ID: byte(0x02), Priority: 10},
-               {ID: byte(0x03), Priority: 10},
-       }, true))
-       return sw
-}
-
-func TestSwitches(t *testing.T) {
-       s1, s2 := makeSwitchPair(t, initSwitchFunc)
-       defer s1.Stop()
-       defer s2.Stop()
-
-       if s1.Peers().Size() != 1 {
-               t.Errorf("Expected exactly 1 peer in s1, got %v", s1.Peers().Size())
-       }
-       if s2.Peers().Size() != 1 {
-               t.Errorf("Expected exactly 1 peer in s2, got %v", s2.Peers().Size())
-       }
-
-       // Lets send some messages
-       ch0Msg := "channel zero"
-       ch1Msg := "channel foo"
-       ch2Msg := "channel bar"
-
-       s1.Broadcast(byte(0x00), ch0Msg)
-       s1.Broadcast(byte(0x01), ch1Msg)
-       s1.Broadcast(byte(0x02), ch2Msg)
-
-       // Wait for things to settle...
-       time.Sleep(5000 * time.Millisecond)
-
-       // Check message on ch0
-       ch0Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x00))
-       if len(ch0Msgs) != 1 {
-               t.Errorf("Expected to have received 1 message in ch0")
-       }
-       if !bytes.Equal(ch0Msgs[0].Bytes, wire.BinaryBytes(ch0Msg)) {
-               t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch0Msg), ch0Msgs[0].Bytes)
-       }
-
-       // Check message on ch1
-       ch1Msgs := s2.Reactor("foo").(*TestReactor).getMsgs(byte(0x01))
-       if len(ch1Msgs) != 1 {
-               t.Errorf("Expected to have received 1 message in ch1")
-       }
-       if !bytes.Equal(ch1Msgs[0].Bytes, wire.BinaryBytes(ch1Msg)) {
-               t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch1Msg), ch1Msgs[0].Bytes)
-       }
-
-       // Check message on ch2
-       ch2Msgs := s2.Reactor("bar").(*TestReactor).getMsgs(byte(0x02))
-       if len(ch2Msgs) != 1 {
-               t.Errorf("Expected to have received 1 message in ch2")
-       }
-       if !bytes.Equal(ch2Msgs[0].Bytes, wire.BinaryBytes(ch2Msg)) {
-               t.Errorf("Unexpected message bytes. Wanted: %X, Got: %X", wire.BinaryBytes(ch2Msg), ch2Msgs[0].Bytes)
-       }
-
-}
-
-func TestConnAddrFilter(t *testing.T) {
-       s1 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
-       s2 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
-
-       c1, c2 := net.Pipe()
-
-       s1.SetAddrFilter(func(addr net.Addr) error {
-               if addr.String() == c1.RemoteAddr().String() {
-                       return fmt.Errorf("Error: pipe is blacklisted")
-               }
-               return nil
-       })
-
-       // connect to good peer
-       go func() {
-               err := s1.addPeerWithConnection(c1)
-               assert.NotNil(t, err, "expected err")
-       }()
-       go func() {
-               err := s2.addPeerWithConnection(c2)
-               assert.NotNil(t, err, "expected err")
-       }()
-
-       // Wait for things to happen, peers to get added...
-       time.Sleep(100 * time.Millisecond * time.Duration(4))
-
-       defer s1.Stop()
-       defer s2.Stop()
-       if s1.Peers().Size() != 0 {
-               t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size())
-       }
-       if s2.Peers().Size() != 0 {
-               t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size())
-       }
-}
-
-func TestConnPubKeyFilter(t *testing.T) {
-       s1 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
-       s2 := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
-
-       c1, c2 := net.Pipe()
-
-       // set pubkey filter
-       s1.SetPubKeyFilter(func(pubkey crypto.PubKeyEd25519) error {
-               if bytes.Equal(pubkey.Bytes(), s2.nodeInfo.PubKey.Bytes()) {
-                       return fmt.Errorf("Error: pipe is blacklisted")
-               }
-               return nil
-       })
-
-       // connect to good peer
-       go func() {
-               err := s1.addPeerWithConnection(c1)
-               assert.NotNil(t, err, "expected err")
-       }()
-       go func() {
-               err := s2.addPeerWithConnection(c2)
-               assert.NotNil(t, err, "expected err")
-       }()
-
-       // Wait for things to happen, peers to get added...
-       time.Sleep(100 * time.Millisecond * time.Duration(4))
-
-       defer s1.Stop()
-       defer s2.Stop()
-       if s1.Peers().Size() != 0 {
-               t.Errorf("Expected s1 not to connect to peers, got %d", s1.Peers().Size())
-       }
-       if s2.Peers().Size() != 0 {
-               t.Errorf("Expected s2 not to connect to peers, got %d", s2.Peers().Size())
-       }
-}
-
-func TestSwitchStopsNonPersistentPeerOnError(t *testing.T) {
-       assert, require := assert.New(t), require.New(t)
-
-       sw := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
-       sw.Start()
-       defer sw.Stop()
-
-       // simulate remote peer
-       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig(config)}
-       rp.Start()
-       defer rp.Stop()
-
-       peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config)
-       require.Nil(err)
-       err = sw.AddPeer(peer)
-       require.Nil(err)
-
-       // simulate failure by closing connection
-       peer.CloseConn()
-
-       time.Sleep(100 * time.Millisecond)
-
-       assert.Zero(sw.Peers().Size())
-       assert.False(peer.IsRunning())
-}
-
-func TestSwitchReconnectsToPersistentPeer(t *testing.T) {
-       assert, require := assert.New(t), require.New(t)
-
-       sw := makeSwitch(config, 1, "testing", "123.123.123", initSwitchFunc)
-       sw.Start()
-       defer sw.Stop()
-
-       // simulate remote peer
-       rp := &remotePeer{PrivKey: crypto.GenPrivKeyEd25519(), Config: DefaultPeerConfig(config)}
-       rp.Start()
-       defer rp.Stop()
-
-       peer, err := newOutboundPeer(rp.Addr(), sw.reactorsByCh, sw.chDescs, sw.StopPeerForError, sw.nodePrivKey, config)
-       peer.makePersistent()
-       require.Nil(err)
-       err = sw.AddPeer(peer)
-       require.Nil(err)
-
-       // simulate failure by closing connection
-       peer.CloseConn()
-
-       // TODO: actually detect the disconnection and wait for reconnect
-       time.Sleep(100 * time.Millisecond)
-
-       assert.NotZero(sw.Peers().Size())
-       assert.False(peer.IsRunning())
-}
-
-func BenchmarkSwitches(b *testing.B) {
-       b.StopTimer()
-
-       s1, s2 := makeSwitchPair(b, func(i int, sw *Switch) *Switch {
-               // Make bar reactors of bar channels each
-               sw.AddReactor("foo", NewTestReactor([]*ChannelDescriptor{
-                       {ID: byte(0x00), Priority: 10},
-                       {ID: byte(0x01), Priority: 10},
-               }, false))
-               sw.AddReactor("bar", NewTestReactor([]*ChannelDescriptor{
-                       {ID: byte(0x02), Priority: 10},
-                       {ID: byte(0x03), Priority: 10},
-               }, false))
-               return sw
-       })
-       defer s1.Stop()
-       defer s2.Stop()
-
-       // Allow time for goroutines to boot up
-       time.Sleep(1000 * time.Millisecond)
-       b.StartTimer()
-
-       numSuccess, numFailure := 0, 0
-
-       // Send random message from foo channel to another
-       for i := 0; i < b.N; i++ {
-               chID := byte(i % 4)
-               successChan := s1.Broadcast(chID, "test data")
-               for s := range successChan {
-                       if s {
-                               numSuccess++
-                       } else {
-                               numFailure++
-                       }
-               }
-       }
-
-       b.Logf("success: %v, failure: %v", numSuccess, numFailure)
-
-       // Allow everything to flush before stopping switches & closing connections.
-       b.StopTimer()
-       time.Sleep(1000 * time.Millisecond)
-}