// Copyright 2012 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements the Socialist Millionaires Protocol as described in // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol // specification is required in order to understand this code and, where // possible, the variable names in the code match up with the spec. package otr import ( "bytes" "crypto/sha256" "errors" "hash" "math/big" ) type smpFailure string func (s smpFailure) Error() string { return string(s) } var smpFailureError = smpFailure("otr: SMP protocol failed") var smpSecretMissingError = smpFailure("otr: mutual secret needed") const smpVersion = 1 const ( smpState1 = iota smpState2 smpState3 smpState4 ) type smpState struct { state int a2, a3, b2, b3, pb, qb *big.Int g2a, g3a *big.Int g2, g3 *big.Int g3b, papb, qaqb, ra *big.Int saved *tlv secret *big.Int question string } func (c *Conversation) startSMP(question string) (tlvs []tlv) { if c.smp.state != smpState1 { tlvs = append(tlvs, c.generateSMPAbort()) } tlvs = append(tlvs, c.generateSMP1(question)) c.smp.question = "" c.smp.state = smpState2 return } func (c *Conversation) resetSMP() { c.smp.state = smpState1 c.smp.secret = nil c.smp.question = "" } func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) { data := in.data switch in.typ { case tlvTypeSMPAbort: if c.smp.state != smpState1 { err = smpFailureError } c.resetSMP() return case tlvTypeSMP1WithQuestion: // We preprocess this into a SMP1 message. nulPos := bytes.IndexByte(data, 0) if nulPos == -1 { err = errors.New("otr: SMP message with question didn't contain a NUL byte") return } c.smp.question = string(data[:nulPos]) data = data[nulPos+1:] } numMPIs, data, ok := getU32(data) if !ok || numMPIs > 20 { err = errors.New("otr: corrupt SMP message") return } mpis := make([]*big.Int, numMPIs) for i := range mpis { var ok bool mpis[i], data, ok = getMPI(data) if !ok { err = errors.New("otr: corrupt SMP message") return } } switch in.typ { case tlvTypeSMP1, tlvTypeSMP1WithQuestion: if c.smp.state != smpState1 { c.resetSMP() out = c.generateSMPAbort() return } if c.smp.secret == nil { err = smpSecretMissingError return } if err = c.processSMP1(mpis); err != nil { return } c.smp.state = smpState3 out = c.generateSMP2() case tlvTypeSMP2: if c.smp.state != smpState2 { c.resetSMP() out = c.generateSMPAbort() return } if out, err = c.processSMP2(mpis); err != nil { out = c.generateSMPAbort() return } c.smp.state = smpState4 case tlvTypeSMP3: if c.smp.state != smpState3 { c.resetSMP() out = c.generateSMPAbort() return } if out, err = c.processSMP3(mpis); err != nil { return } c.smp.state = smpState1 c.smp.secret = nil complete = true case tlvTypeSMP4: if c.smp.state != smpState4 { c.resetSMP() out = c.generateSMPAbort() return } if err = c.processSMP4(mpis); err != nil { out = c.generateSMPAbort() return } c.smp.state = smpState1 c.smp.secret = nil complete = true default: panic("unknown SMP message") } return } func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) { h := sha256.New() h.Write([]byte{smpVersion}) if weStarted { h.Write(c.PrivateKey.PublicKey.Fingerprint()) h.Write(c.TheirPublicKey.Fingerprint()) } else { h.Write(c.TheirPublicKey.Fingerprint()) h.Write(c.PrivateKey.PublicKey.Fingerprint()) } h.Write(c.SSID[:]) h.Write(mutualSecret) c.smp.secret = new(big.Int).SetBytes(h.Sum(nil)) } func (c *Conversation) generateSMP1(question string) tlv { var randBuf [16]byte c.smp.a2 = c.randMPI(randBuf[:]) c.smp.a3 = c.randMPI(randBuf[:]) g2a := new(big.Int).Exp(g, c.smp.a2, p) g3a := new(big.Int).Exp(g, c.smp.a3, p) h := sha256.New() r2 := c.randMPI(randBuf[:]) r := new(big.Int).Exp(g, r2, p) c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r)) d2 := new(big.Int).Mul(c.smp.a2, c2) d2.Sub(r2, d2) d2.Mod(d2, q) if d2.Sign() < 0 { d2.Add(d2, q) } r3 := c.randMPI(randBuf[:]) r.Exp(g, r3, p) c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r)) d3 := new(big.Int).Mul(c.smp.a3, c3) d3.Sub(r3, d3) d3.Mod(d3, q) if d3.Sign() < 0 { d3.Add(d3, q) } var ret tlv if len(question) > 0 { ret.typ = tlvTypeSMP1WithQuestion ret.data = append(ret.data, question...) ret.data = append(ret.data, 0) } else { ret.typ = tlvTypeSMP1 } ret.data = appendU32(ret.data, 6) ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3) return ret } func (c *Conversation) processSMP1(mpis []*big.Int) error { if len(mpis) != 6 { return errors.New("otr: incorrect number of arguments in SMP1 message") } g2a := mpis[0] c2 := mpis[1] d2 := mpis[2] g3a := mpis[3] c3 := mpis[4] d3 := mpis[5] h := sha256.New() r := new(big.Int).Exp(g, d2, p) s := new(big.Int).Exp(g2a, c2, p) r.Mul(r, s) r.Mod(r, p) t := new(big.Int).SetBytes(hashMPIs(h, 1, r)) if c2.Cmp(t) != 0 { return errors.New("otr: ZKP c2 incorrect in SMP1 message") } r.Exp(g, d3, p) s.Exp(g3a, c3, p) r.Mul(r, s) r.Mod(r, p) t.SetBytes(hashMPIs(h, 2, r)) if c3.Cmp(t) != 0 { return errors.New("otr: ZKP c3 incorrect in SMP1 message") } c.smp.g2a = g2a c.smp.g3a = g3a return nil } func (c *Conversation) generateSMP2() tlv { var randBuf [16]byte b2 := c.randMPI(randBuf[:]) c.smp.b3 = c.randMPI(randBuf[:]) r2 := c.randMPI(randBuf[:]) r3 := c.randMPI(randBuf[:]) r4 := c.randMPI(randBuf[:]) r5 := c.randMPI(randBuf[:]) r6 := c.randMPI(randBuf[:]) g2b := new(big.Int).Exp(g, b2, p) g3b := new(big.Int).Exp(g, c.smp.b3, p) r := new(big.Int).Exp(g, r2, p) h := sha256.New() c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r)) d2 := new(big.Int).Mul(b2, c2) d2.Sub(r2, d2) d2.Mod(d2, q) if d2.Sign() < 0 { d2.Add(d2, q) } r.Exp(g, r3, p) c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r)) d3 := new(big.Int).Mul(c.smp.b3, c3) d3.Sub(r3, d3) d3.Mod(d3, q) if d3.Sign() < 0 { d3.Add(d3, q) } c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p) c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p) c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p) c.smp.qb = new(big.Int).Exp(g, r4, p) r.Exp(c.smp.g2, c.smp.secret, p) c.smp.qb.Mul(c.smp.qb, r) c.smp.qb.Mod(c.smp.qb, p) s := new(big.Int) s.Exp(c.smp.g2, r6, p) r.Exp(g, r5, p) s.Mul(r, s) s.Mod(s, p) r.Exp(c.smp.g3, r5, p) cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s)) // D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q s.Mul(r4, cp) r.Sub(r5, s) d5 := new(big.Int).Mod(r, q) if d5.Sign() < 0 { d5.Add(d5, q) } s.Mul(c.smp.secret, cp) r.Sub(r6, s) d6 := new(big.Int).Mod(r, q) if d6.Sign() < 0 { d6.Add(d6, q) } var ret tlv ret.typ = tlvTypeSMP2 ret.data = appendU32(ret.data, 11) ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6) return ret } func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) { if len(mpis) != 11 { err = errors.New("otr: incorrect number of arguments in SMP2 message") return } g2b := mpis[0] c2 := mpis[1] d2 := mpis[2] g3b := mpis[3] c3 := mpis[4] d3 := mpis[5] pb := mpis[6] qb := mpis[7] cp := mpis[8] d5 := mpis[9] d6 := mpis[10] h := sha256.New() r := new(big.Int).Exp(g, d2, p) s := new(big.Int).Exp(g2b, c2, p) r.Mul(r, s) r.Mod(r, p) s.SetBytes(hashMPIs(h, 3, r)) if c2.Cmp(s) != 0 { err = errors.New("otr: ZKP c2 failed in SMP2 message") return } r.Exp(g, d3, p) s.Exp(g3b, c3, p) r.Mul(r, s) r.Mod(r, p) s.SetBytes(hashMPIs(h, 4, r)) if c3.Cmp(s) != 0 { err = errors.New("otr: ZKP c3 failed in SMP2 message") return } c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p) c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p) r.Exp(g, d5, p) s.Exp(c.smp.g2, d6, p) r.Mul(r, s) s.Exp(qb, cp, p) r.Mul(r, s) r.Mod(r, p) s.Exp(c.smp.g3, d5, p) t := new(big.Int).Exp(pb, cp, p) s.Mul(s, t) s.Mod(s, p) t.SetBytes(hashMPIs(h, 5, s, r)) if cp.Cmp(t) != 0 { err = errors.New("otr: ZKP cP failed in SMP2 message") return } var randBuf [16]byte r4 := c.randMPI(randBuf[:]) r5 := c.randMPI(randBuf[:]) r6 := c.randMPI(randBuf[:]) r7 := c.randMPI(randBuf[:]) pa := new(big.Int).Exp(c.smp.g3, r4, p) r.Exp(c.smp.g2, c.smp.secret, p) qa := new(big.Int).Exp(g, r4, p) qa.Mul(qa, r) qa.Mod(qa, p) r.Exp(g, r5, p) s.Exp(c.smp.g2, r6, p) r.Mul(r, s) r.Mod(r, p) s.Exp(c.smp.g3, r5, p) cp.SetBytes(hashMPIs(h, 6, s, r)) r.Mul(r4, cp) d5 = new(big.Int).Sub(r5, r) d5.Mod(d5, q) if d5.Sign() < 0 { d5.Add(d5, q) } r.Mul(c.smp.secret, cp) d6 = new(big.Int).Sub(r6, r) d6.Mod(d6, q) if d6.Sign() < 0 { d6.Add(d6, q) } r.ModInverse(qb, p) qaqb := new(big.Int).Mul(qa, r) qaqb.Mod(qaqb, p) ra := new(big.Int).Exp(qaqb, c.smp.a3, p) r.Exp(qaqb, r7, p) s.Exp(g, r7, p) cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r)) r.Mul(c.smp.a3, cr) d7 := new(big.Int).Sub(r7, r) d7.Mod(d7, q) if d7.Sign() < 0 { d7.Add(d7, q) } c.smp.g3b = g3b c.smp.qaqb = qaqb r.ModInverse(pb, p) c.smp.papb = new(big.Int).Mul(pa, r) c.smp.papb.Mod(c.smp.papb, p) c.smp.ra = ra out.typ = tlvTypeSMP3 out.data = appendU32(out.data, 8) out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7) return } func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) { if len(mpis) != 8 { err = errors.New("otr: incorrect number of arguments in SMP3 message") return } pa := mpis[0] qa := mpis[1] cp := mpis[2] d5 := mpis[3] d6 := mpis[4] ra := mpis[5] cr := mpis[6] d7 := mpis[7] h := sha256.New() r := new(big.Int).Exp(g, d5, p) s := new(big.Int).Exp(c.smp.g2, d6, p) r.Mul(r, s) s.Exp(qa, cp, p) r.Mul(r, s) r.Mod(r, p) s.Exp(c.smp.g3, d5, p) t := new(big.Int).Exp(pa, cp, p) s.Mul(s, t) s.Mod(s, p) t.SetBytes(hashMPIs(h, 6, s, r)) if t.Cmp(cp) != 0 { err = errors.New("otr: ZKP cP failed in SMP3 message") return } r.ModInverse(c.smp.qb, p) qaqb := new(big.Int).Mul(qa, r) qaqb.Mod(qaqb, p) r.Exp(qaqb, d7, p) s.Exp(ra, cr, p) r.Mul(r, s) r.Mod(r, p) s.Exp(g, d7, p) t.Exp(c.smp.g3a, cr, p) s.Mul(s, t) s.Mod(s, p) t.SetBytes(hashMPIs(h, 7, s, r)) if t.Cmp(cr) != 0 { err = errors.New("otr: ZKP cR failed in SMP3 message") return } var randBuf [16]byte r7 := c.randMPI(randBuf[:]) rb := new(big.Int).Exp(qaqb, c.smp.b3, p) r.Exp(qaqb, r7, p) s.Exp(g, r7, p) cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r)) r.Mul(c.smp.b3, cr) d7 = new(big.Int).Sub(r7, r) d7.Mod(d7, q) if d7.Sign() < 0 { d7.Add(d7, q) } out.typ = tlvTypeSMP4 out.data = appendU32(out.data, 3) out.data = appendMPIs(out.data, rb, cr, d7) r.ModInverse(c.smp.pb, p) r.Mul(pa, r) r.Mod(r, p) s.Exp(ra, c.smp.b3, p) if r.Cmp(s) != 0 { err = smpFailureError } return } func (c *Conversation) processSMP4(mpis []*big.Int) error { if len(mpis) != 3 { return errors.New("otr: incorrect number of arguments in SMP4 message") } rb := mpis[0] cr := mpis[1] d7 := mpis[2] h := sha256.New() r := new(big.Int).Exp(c.smp.qaqb, d7, p) s := new(big.Int).Exp(rb, cr, p) r.Mul(r, s) r.Mod(r, p) s.Exp(g, d7, p) t := new(big.Int).Exp(c.smp.g3b, cr, p) s.Mul(s, t) s.Mod(s, p) t.SetBytes(hashMPIs(h, 8, s, r)) if t.Cmp(cr) != 0 { return errors.New("otr: ZKP cR failed in SMP4 message") } r.Exp(rb, c.smp.a3, p) if r.Cmp(c.smp.papb) != 0 { return smpFailureError } return nil } func (c *Conversation) generateSMPAbort() tlv { return tlv{typ: tlvTypeSMPAbort} } func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte { if h != nil { h.Reset() } else { h = sha256.New() } h.Write([]byte{magic}) for _, mpi := range mpis { h.Write(appendMPI(nil, mpi)) } return h.Sum(nil) }