OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / golang.org / x / crypto / ssh / test / test_unix_test.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // +build darwin dragonfly freebsd linux netbsd openbsd plan9
6
7 package test
8
9 // functional test harness for unix.
10
11 import (
12         "bytes"
13         "fmt"
14         "io/ioutil"
15         "log"
16         "net"
17         "os"
18         "os/exec"
19         "os/user"
20         "path/filepath"
21         "testing"
22         "text/template"
23
24         "golang.org/x/crypto/ssh"
25         "golang.org/x/crypto/ssh/testdata"
26 )
27
28 const sshd_config = `
29 Protocol 2
30 HostKey {{.Dir}}/id_rsa
31 HostKey {{.Dir}}/id_dsa
32 HostKey {{.Dir}}/id_ecdsa
33 HostCertificate {{.Dir}}/id_rsa-cert.pub
34 Pidfile {{.Dir}}/sshd.pid
35 #UsePrivilegeSeparation no
36 KeyRegenerationInterval 3600
37 ServerKeyBits 768
38 SyslogFacility AUTH
39 LogLevel DEBUG2
40 LoginGraceTime 120
41 PermitRootLogin no
42 StrictModes no
43 RSAAuthentication yes
44 PubkeyAuthentication yes
45 AuthorizedKeysFile      {{.Dir}}/authorized_keys
46 TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
47 IgnoreRhosts yes
48 RhostsRSAAuthentication no
49 HostbasedAuthentication no
50 PubkeyAcceptedKeyTypes=*
51 `
52
53 var configTmpl = template.Must(template.New("").Parse(sshd_config))
54
55 type server struct {
56         t          *testing.T
57         cleanup    func() // executed during Shutdown
58         configfile string
59         cmd        *exec.Cmd
60         output     bytes.Buffer // holds stderr from sshd process
61
62         // Client half of the network connection.
63         clientConn net.Conn
64 }
65
66 func username() string {
67         var username string
68         if user, err := user.Current(); err == nil {
69                 username = user.Username
70         } else {
71                 // user.Current() currently requires cgo. If an error is
72                 // returned attempt to get the username from the environment.
73                 log.Printf("user.Current: %v; falling back on $USER", err)
74                 username = os.Getenv("USER")
75         }
76         if username == "" {
77                 panic("Unable to get username")
78         }
79         return username
80 }
81
82 type storedHostKey struct {
83         // keys map from an algorithm string to binary key data.
84         keys map[string][]byte
85
86         // checkCount counts the Check calls. Used for testing
87         // rekeying.
88         checkCount int
89 }
90
91 func (k *storedHostKey) Add(key ssh.PublicKey) {
92         if k.keys == nil {
93                 k.keys = map[string][]byte{}
94         }
95         k.keys[key.Type()] = key.Marshal()
96 }
97
98 func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
99         k.checkCount++
100         algo := key.Type()
101
102         if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
103                 return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
104         }
105         return nil
106 }
107
108 func hostKeyDB() *storedHostKey {
109         keyChecker := &storedHostKey{}
110         keyChecker.Add(testPublicKeys["ecdsa"])
111         keyChecker.Add(testPublicKeys["rsa"])
112         keyChecker.Add(testPublicKeys["dsa"])
113         return keyChecker
114 }
115
116 func clientConfig() *ssh.ClientConfig {
117         config := &ssh.ClientConfig{
118                 User: username(),
119                 Auth: []ssh.AuthMethod{
120                         ssh.PublicKeys(testSigners["user"]),
121                 },
122                 HostKeyCallback: hostKeyDB().Check,
123                 HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
124                         ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
125                         ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
126                         ssh.KeyAlgoED25519,
127                 },
128         }
129         return config
130 }
131
132 // unixConnection creates two halves of a connected net.UnixConn.  It
133 // is used for connecting the Go SSH client with sshd without opening
134 // ports.
135 func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
136         dir, err := ioutil.TempDir("", "unixConnection")
137         if err != nil {
138                 return nil, nil, err
139         }
140         defer os.Remove(dir)
141
142         addr := filepath.Join(dir, "ssh")
143         listener, err := net.Listen("unix", addr)
144         if err != nil {
145                 return nil, nil, err
146         }
147         defer listener.Close()
148         c1, err := net.Dial("unix", addr)
149         if err != nil {
150                 return nil, nil, err
151         }
152
153         c2, err := listener.Accept()
154         if err != nil {
155                 c1.Close()
156                 return nil, nil, err
157         }
158
159         return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
160 }
161
162 func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
163         return s.TryDialWithAddr(config, "")
164 }
165
166 // addr is the user specified host:port. While we don't actually dial it,
167 // we need to know this for host key matching
168 func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
169         sshd, err := exec.LookPath("sshd")
170         if err != nil {
171                 s.t.Skipf("skipping test: %v", err)
172         }
173
174         c1, c2, err := unixConnection()
175         if err != nil {
176                 s.t.Fatalf("unixConnection: %v", err)
177         }
178
179         s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
180         f, err := c2.File()
181         if err != nil {
182                 s.t.Fatalf("UnixConn.File: %v", err)
183         }
184         defer f.Close()
185         s.cmd.Stdin = f
186         s.cmd.Stdout = f
187         s.cmd.Stderr = &s.output
188         if err := s.cmd.Start(); err != nil {
189                 s.t.Fail()
190                 s.Shutdown()
191                 s.t.Fatalf("s.cmd.Start: %v", err)
192         }
193         s.clientConn = c1
194         conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
195         if err != nil {
196                 return nil, err
197         }
198         return ssh.NewClient(conn, chans, reqs), nil
199 }
200
201 func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
202         conn, err := s.TryDial(config)
203         if err != nil {
204                 s.t.Fail()
205                 s.Shutdown()
206                 s.t.Fatalf("ssh.Client: %v", err)
207         }
208         return conn
209 }
210
211 func (s *server) Shutdown() {
212         if s.cmd != nil && s.cmd.Process != nil {
213                 // Don't check for errors; if it fails it's most
214                 // likely "os: process already finished", and we don't
215                 // care about that. Use os.Interrupt, so child
216                 // processes are killed too.
217                 s.cmd.Process.Signal(os.Interrupt)
218                 s.cmd.Wait()
219         }
220         if s.t.Failed() {
221                 // log any output from sshd process
222                 s.t.Logf("sshd: %s", s.output.String())
223         }
224         s.cleanup()
225 }
226
227 func writeFile(path string, contents []byte) {
228         f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
229         if err != nil {
230                 panic(err)
231         }
232         defer f.Close()
233         if _, err := f.Write(contents); err != nil {
234                 panic(err)
235         }
236 }
237
238 // newServer returns a new mock ssh server.
239 func newServer(t *testing.T) *server {
240         if testing.Short() {
241                 t.Skip("skipping test due to -short")
242         }
243         dir, err := ioutil.TempDir("", "sshtest")
244         if err != nil {
245                 t.Fatal(err)
246         }
247         f, err := os.Create(filepath.Join(dir, "sshd_config"))
248         if err != nil {
249                 t.Fatal(err)
250         }
251         err = configTmpl.Execute(f, map[string]string{
252                 "Dir": dir,
253         })
254         if err != nil {
255                 t.Fatal(err)
256         }
257         f.Close()
258
259         for k, v := range testdata.PEMBytes {
260                 filename := "id_" + k
261                 writeFile(filepath.Join(dir, filename), v)
262                 writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
263         }
264
265         for k, v := range testdata.SSHCertificates {
266                 filename := "id_" + k + "-cert.pub"
267                 writeFile(filepath.Join(dir, filename), v)
268         }
269
270         var authkeys bytes.Buffer
271         for k, _ := range testdata.PEMBytes {
272                 authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
273         }
274         writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
275
276         return &server{
277                 t:          t,
278                 configfile: f.Name(),
279                 cleanup: func() {
280                         if err := os.RemoveAll(dir); err != nil {
281                                 t.Error(err)
282                         }
283                 },
284         }
285 }
286
287 func newTempSocket(t *testing.T) (string, func()) {
288         dir, err := ioutil.TempDir("", "socket")
289         if err != nil {
290                 t.Fatal(err)
291         }
292         deferFunc := func() { os.RemoveAll(dir) }
293         addr := filepath.Join(dir, "sock")
294         return addr, deferFunc
295 }