OSDN Git Service

Hulk did something
[bytom/vapor.git] / vendor / golang.org / x / crypto / otr / smp.go
1 // Copyright 2012 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 // This file implements the Socialist Millionaires Protocol as described in
6 // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol
7 // specification is required in order to understand this code and, where
8 // possible, the variable names in the code match up with the spec.
9
10 package otr
11
12 import (
13         "bytes"
14         "crypto/sha256"
15         "errors"
16         "hash"
17         "math/big"
18 )
19
20 type smpFailure string
21
22 func (s smpFailure) Error() string {
23         return string(s)
24 }
25
26 var smpFailureError = smpFailure("otr: SMP protocol failed")
27 var smpSecretMissingError = smpFailure("otr: mutual secret needed")
28
29 const smpVersion = 1
30
31 const (
32         smpState1 = iota
33         smpState2
34         smpState3
35         smpState4
36 )
37
38 type smpState struct {
39         state                  int
40         a2, a3, b2, b3, pb, qb *big.Int
41         g2a, g3a               *big.Int
42         g2, g3                 *big.Int
43         g3b, papb, qaqb, ra    *big.Int
44         saved                  *tlv
45         secret                 *big.Int
46         question               string
47 }
48
49 func (c *Conversation) startSMP(question string) (tlvs []tlv) {
50         if c.smp.state != smpState1 {
51                 tlvs = append(tlvs, c.generateSMPAbort())
52         }
53         tlvs = append(tlvs, c.generateSMP1(question))
54         c.smp.question = ""
55         c.smp.state = smpState2
56         return
57 }
58
59 func (c *Conversation) resetSMP() {
60         c.smp.state = smpState1
61         c.smp.secret = nil
62         c.smp.question = ""
63 }
64
65 func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
66         data := in.data
67
68         switch in.typ {
69         case tlvTypeSMPAbort:
70                 if c.smp.state != smpState1 {
71                         err = smpFailureError
72                 }
73                 c.resetSMP()
74                 return
75         case tlvTypeSMP1WithQuestion:
76                 // We preprocess this into a SMP1 message.
77                 nulPos := bytes.IndexByte(data, 0)
78                 if nulPos == -1 {
79                         err = errors.New("otr: SMP message with question didn't contain a NUL byte")
80                         return
81                 }
82                 c.smp.question = string(data[:nulPos])
83                 data = data[nulPos+1:]
84         }
85
86         numMPIs, data, ok := getU32(data)
87         if !ok || numMPIs > 20 {
88                 err = errors.New("otr: corrupt SMP message")
89                 return
90         }
91
92         mpis := make([]*big.Int, numMPIs)
93         for i := range mpis {
94                 var ok bool
95                 mpis[i], data, ok = getMPI(data)
96                 if !ok {
97                         err = errors.New("otr: corrupt SMP message")
98                         return
99                 }
100         }
101
102         switch in.typ {
103         case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
104                 if c.smp.state != smpState1 {
105                         c.resetSMP()
106                         out = c.generateSMPAbort()
107                         return
108                 }
109                 if c.smp.secret == nil {
110                         err = smpSecretMissingError
111                         return
112                 }
113                 if err = c.processSMP1(mpis); err != nil {
114                         return
115                 }
116                 c.smp.state = smpState3
117                 out = c.generateSMP2()
118         case tlvTypeSMP2:
119                 if c.smp.state != smpState2 {
120                         c.resetSMP()
121                         out = c.generateSMPAbort()
122                         return
123                 }
124                 if out, err = c.processSMP2(mpis); err != nil {
125                         out = c.generateSMPAbort()
126                         return
127                 }
128                 c.smp.state = smpState4
129         case tlvTypeSMP3:
130                 if c.smp.state != smpState3 {
131                         c.resetSMP()
132                         out = c.generateSMPAbort()
133                         return
134                 }
135                 if out, err = c.processSMP3(mpis); err != nil {
136                         return
137                 }
138                 c.smp.state = smpState1
139                 c.smp.secret = nil
140                 complete = true
141         case tlvTypeSMP4:
142                 if c.smp.state != smpState4 {
143                         c.resetSMP()
144                         out = c.generateSMPAbort()
145                         return
146                 }
147                 if err = c.processSMP4(mpis); err != nil {
148                         out = c.generateSMPAbort()
149                         return
150                 }
151                 c.smp.state = smpState1
152                 c.smp.secret = nil
153                 complete = true
154         default:
155                 panic("unknown SMP message")
156         }
157
158         return
159 }
160
161 func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
162         h := sha256.New()
163         h.Write([]byte{smpVersion})
164         if weStarted {
165                 h.Write(c.PrivateKey.PublicKey.Fingerprint())
166                 h.Write(c.TheirPublicKey.Fingerprint())
167         } else {
168                 h.Write(c.TheirPublicKey.Fingerprint())
169                 h.Write(c.PrivateKey.PublicKey.Fingerprint())
170         }
171         h.Write(c.SSID[:])
172         h.Write(mutualSecret)
173         c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
174 }
175
176 func (c *Conversation) generateSMP1(question string) tlv {
177         var randBuf [16]byte
178         c.smp.a2 = c.randMPI(randBuf[:])
179         c.smp.a3 = c.randMPI(randBuf[:])
180         g2a := new(big.Int).Exp(g, c.smp.a2, p)
181         g3a := new(big.Int).Exp(g, c.smp.a3, p)
182         h := sha256.New()
183
184         r2 := c.randMPI(randBuf[:])
185         r := new(big.Int).Exp(g, r2, p)
186         c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
187         d2 := new(big.Int).Mul(c.smp.a2, c2)
188         d2.Sub(r2, d2)
189         d2.Mod(d2, q)
190         if d2.Sign() < 0 {
191                 d2.Add(d2, q)
192         }
193
194         r3 := c.randMPI(randBuf[:])
195         r.Exp(g, r3, p)
196         c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
197         d3 := new(big.Int).Mul(c.smp.a3, c3)
198         d3.Sub(r3, d3)
199         d3.Mod(d3, q)
200         if d3.Sign() < 0 {
201                 d3.Add(d3, q)
202         }
203
204         var ret tlv
205         if len(question) > 0 {
206                 ret.typ = tlvTypeSMP1WithQuestion
207                 ret.data = append(ret.data, question...)
208                 ret.data = append(ret.data, 0)
209         } else {
210                 ret.typ = tlvTypeSMP1
211         }
212         ret.data = appendU32(ret.data, 6)
213         ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
214         return ret
215 }
216
217 func (c *Conversation) processSMP1(mpis []*big.Int) error {
218         if len(mpis) != 6 {
219                 return errors.New("otr: incorrect number of arguments in SMP1 message")
220         }
221         g2a := mpis[0]
222         c2 := mpis[1]
223         d2 := mpis[2]
224         g3a := mpis[3]
225         c3 := mpis[4]
226         d3 := mpis[5]
227         h := sha256.New()
228
229         r := new(big.Int).Exp(g, d2, p)
230         s := new(big.Int).Exp(g2a, c2, p)
231         r.Mul(r, s)
232         r.Mod(r, p)
233         t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
234         if c2.Cmp(t) != 0 {
235                 return errors.New("otr: ZKP c2 incorrect in SMP1 message")
236         }
237         r.Exp(g, d3, p)
238         s.Exp(g3a, c3, p)
239         r.Mul(r, s)
240         r.Mod(r, p)
241         t.SetBytes(hashMPIs(h, 2, r))
242         if c3.Cmp(t) != 0 {
243                 return errors.New("otr: ZKP c3 incorrect in SMP1 message")
244         }
245
246         c.smp.g2a = g2a
247         c.smp.g3a = g3a
248         return nil
249 }
250
251 func (c *Conversation) generateSMP2() tlv {
252         var randBuf [16]byte
253         b2 := c.randMPI(randBuf[:])
254         c.smp.b3 = c.randMPI(randBuf[:])
255         r2 := c.randMPI(randBuf[:])
256         r3 := c.randMPI(randBuf[:])
257         r4 := c.randMPI(randBuf[:])
258         r5 := c.randMPI(randBuf[:])
259         r6 := c.randMPI(randBuf[:])
260
261         g2b := new(big.Int).Exp(g, b2, p)
262         g3b := new(big.Int).Exp(g, c.smp.b3, p)
263
264         r := new(big.Int).Exp(g, r2, p)
265         h := sha256.New()
266         c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
267         d2 := new(big.Int).Mul(b2, c2)
268         d2.Sub(r2, d2)
269         d2.Mod(d2, q)
270         if d2.Sign() < 0 {
271                 d2.Add(d2, q)
272         }
273
274         r.Exp(g, r3, p)
275         c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
276         d3 := new(big.Int).Mul(c.smp.b3, c3)
277         d3.Sub(r3, d3)
278         d3.Mod(d3, q)
279         if d3.Sign() < 0 {
280                 d3.Add(d3, q)
281         }
282
283         c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
284         c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
285         c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
286         c.smp.qb = new(big.Int).Exp(g, r4, p)
287         r.Exp(c.smp.g2, c.smp.secret, p)
288         c.smp.qb.Mul(c.smp.qb, r)
289         c.smp.qb.Mod(c.smp.qb, p)
290
291         s := new(big.Int)
292         s.Exp(c.smp.g2, r6, p)
293         r.Exp(g, r5, p)
294         s.Mul(r, s)
295         s.Mod(s, p)
296         r.Exp(c.smp.g3, r5, p)
297         cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))
298
299         // D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q
300
301         s.Mul(r4, cp)
302         r.Sub(r5, s)
303         d5 := new(big.Int).Mod(r, q)
304         if d5.Sign() < 0 {
305                 d5.Add(d5, q)
306         }
307
308         s.Mul(c.smp.secret, cp)
309         r.Sub(r6, s)
310         d6 := new(big.Int).Mod(r, q)
311         if d6.Sign() < 0 {
312                 d6.Add(d6, q)
313         }
314
315         var ret tlv
316         ret.typ = tlvTypeSMP2
317         ret.data = appendU32(ret.data, 11)
318         ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
319         return ret
320 }
321
322 func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
323         if len(mpis) != 11 {
324                 err = errors.New("otr: incorrect number of arguments in SMP2 message")
325                 return
326         }
327         g2b := mpis[0]
328         c2 := mpis[1]
329         d2 := mpis[2]
330         g3b := mpis[3]
331         c3 := mpis[4]
332         d3 := mpis[5]
333         pb := mpis[6]
334         qb := mpis[7]
335         cp := mpis[8]
336         d5 := mpis[9]
337         d6 := mpis[10]
338         h := sha256.New()
339
340         r := new(big.Int).Exp(g, d2, p)
341         s := new(big.Int).Exp(g2b, c2, p)
342         r.Mul(r, s)
343         r.Mod(r, p)
344         s.SetBytes(hashMPIs(h, 3, r))
345         if c2.Cmp(s) != 0 {
346                 err = errors.New("otr: ZKP c2 failed in SMP2 message")
347                 return
348         }
349
350         r.Exp(g, d3, p)
351         s.Exp(g3b, c3, p)
352         r.Mul(r, s)
353         r.Mod(r, p)
354         s.SetBytes(hashMPIs(h, 4, r))
355         if c3.Cmp(s) != 0 {
356                 err = errors.New("otr: ZKP c3 failed in SMP2 message")
357                 return
358         }
359
360         c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
361         c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
362
363         r.Exp(g, d5, p)
364         s.Exp(c.smp.g2, d6, p)
365         r.Mul(r, s)
366         s.Exp(qb, cp, p)
367         r.Mul(r, s)
368         r.Mod(r, p)
369
370         s.Exp(c.smp.g3, d5, p)
371         t := new(big.Int).Exp(pb, cp, p)
372         s.Mul(s, t)
373         s.Mod(s, p)
374         t.SetBytes(hashMPIs(h, 5, s, r))
375         if cp.Cmp(t) != 0 {
376                 err = errors.New("otr: ZKP cP failed in SMP2 message")
377                 return
378         }
379
380         var randBuf [16]byte
381         r4 := c.randMPI(randBuf[:])
382         r5 := c.randMPI(randBuf[:])
383         r6 := c.randMPI(randBuf[:])
384         r7 := c.randMPI(randBuf[:])
385
386         pa := new(big.Int).Exp(c.smp.g3, r4, p)
387         r.Exp(c.smp.g2, c.smp.secret, p)
388         qa := new(big.Int).Exp(g, r4, p)
389         qa.Mul(qa, r)
390         qa.Mod(qa, p)
391
392         r.Exp(g, r5, p)
393         s.Exp(c.smp.g2, r6, p)
394         r.Mul(r, s)
395         r.Mod(r, p)
396
397         s.Exp(c.smp.g3, r5, p)
398         cp.SetBytes(hashMPIs(h, 6, s, r))
399
400         r.Mul(r4, cp)
401         d5 = new(big.Int).Sub(r5, r)
402         d5.Mod(d5, q)
403         if d5.Sign() < 0 {
404                 d5.Add(d5, q)
405         }
406
407         r.Mul(c.smp.secret, cp)
408         d6 = new(big.Int).Sub(r6, r)
409         d6.Mod(d6, q)
410         if d6.Sign() < 0 {
411                 d6.Add(d6, q)
412         }
413
414         r.ModInverse(qb, p)
415         qaqb := new(big.Int).Mul(qa, r)
416         qaqb.Mod(qaqb, p)
417
418         ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
419         r.Exp(qaqb, r7, p)
420         s.Exp(g, r7, p)
421         cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
422
423         r.Mul(c.smp.a3, cr)
424         d7 := new(big.Int).Sub(r7, r)
425         d7.Mod(d7, q)
426         if d7.Sign() < 0 {
427                 d7.Add(d7, q)
428         }
429
430         c.smp.g3b = g3b
431         c.smp.qaqb = qaqb
432
433         r.ModInverse(pb, p)
434         c.smp.papb = new(big.Int).Mul(pa, r)
435         c.smp.papb.Mod(c.smp.papb, p)
436         c.smp.ra = ra
437
438         out.typ = tlvTypeSMP3
439         out.data = appendU32(out.data, 8)
440         out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
441         return
442 }
443
444 func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
445         if len(mpis) != 8 {
446                 err = errors.New("otr: incorrect number of arguments in SMP3 message")
447                 return
448         }
449         pa := mpis[0]
450         qa := mpis[1]
451         cp := mpis[2]
452         d5 := mpis[3]
453         d6 := mpis[4]
454         ra := mpis[5]
455         cr := mpis[6]
456         d7 := mpis[7]
457         h := sha256.New()
458
459         r := new(big.Int).Exp(g, d5, p)
460         s := new(big.Int).Exp(c.smp.g2, d6, p)
461         r.Mul(r, s)
462         s.Exp(qa, cp, p)
463         r.Mul(r, s)
464         r.Mod(r, p)
465
466         s.Exp(c.smp.g3, d5, p)
467         t := new(big.Int).Exp(pa, cp, p)
468         s.Mul(s, t)
469         s.Mod(s, p)
470         t.SetBytes(hashMPIs(h, 6, s, r))
471         if t.Cmp(cp) != 0 {
472                 err = errors.New("otr: ZKP cP failed in SMP3 message")
473                 return
474         }
475
476         r.ModInverse(c.smp.qb, p)
477         qaqb := new(big.Int).Mul(qa, r)
478         qaqb.Mod(qaqb, p)
479
480         r.Exp(qaqb, d7, p)
481         s.Exp(ra, cr, p)
482         r.Mul(r, s)
483         r.Mod(r, p)
484
485         s.Exp(g, d7, p)
486         t.Exp(c.smp.g3a, cr, p)
487         s.Mul(s, t)
488         s.Mod(s, p)
489         t.SetBytes(hashMPIs(h, 7, s, r))
490         if t.Cmp(cr) != 0 {
491                 err = errors.New("otr: ZKP cR failed in SMP3 message")
492                 return
493         }
494
495         var randBuf [16]byte
496         r7 := c.randMPI(randBuf[:])
497         rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
498
499         r.Exp(qaqb, r7, p)
500         s.Exp(g, r7, p)
501         cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
502
503         r.Mul(c.smp.b3, cr)
504         d7 = new(big.Int).Sub(r7, r)
505         d7.Mod(d7, q)
506         if d7.Sign() < 0 {
507                 d7.Add(d7, q)
508         }
509
510         out.typ = tlvTypeSMP4
511         out.data = appendU32(out.data, 3)
512         out.data = appendMPIs(out.data, rb, cr, d7)
513
514         r.ModInverse(c.smp.pb, p)
515         r.Mul(pa, r)
516         r.Mod(r, p)
517         s.Exp(ra, c.smp.b3, p)
518         if r.Cmp(s) != 0 {
519                 err = smpFailureError
520         }
521
522         return
523 }
524
525 func (c *Conversation) processSMP4(mpis []*big.Int) error {
526         if len(mpis) != 3 {
527                 return errors.New("otr: incorrect number of arguments in SMP4 message")
528         }
529         rb := mpis[0]
530         cr := mpis[1]
531         d7 := mpis[2]
532         h := sha256.New()
533
534         r := new(big.Int).Exp(c.smp.qaqb, d7, p)
535         s := new(big.Int).Exp(rb, cr, p)
536         r.Mul(r, s)
537         r.Mod(r, p)
538
539         s.Exp(g, d7, p)
540         t := new(big.Int).Exp(c.smp.g3b, cr, p)
541         s.Mul(s, t)
542         s.Mod(s, p)
543         t.SetBytes(hashMPIs(h, 8, s, r))
544         if t.Cmp(cr) != 0 {
545                 return errors.New("otr: ZKP cR failed in SMP4 message")
546         }
547
548         r.Exp(rb, c.smp.a3, p)
549         if r.Cmp(c.smp.papb) != 0 {
550                 return smpFailureError
551         }
552
553         return nil
554 }
555
556 func (c *Conversation) generateSMPAbort() tlv {
557         return tlv{typ: tlvTypeSMPAbort}
558 }
559
560 func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
561         if h != nil {
562                 h.Reset()
563         } else {
564                 h = sha256.New()
565         }
566
567         h.Write([]byte{magic})
568         for _, mpi := range mpis {
569                 h.Write(appendMPI(nil, mpi))
570         }
571         return h.Sum(nil)
572 }