OSDN Git Service

23f6fc24247ebb96d74d1b1b087cfb577de9f12d
[bytom/bytom.git] / crypto / ed25519 / chainkd / chainkd.go
1 package chainkd
2
3 import (
4         "crypto/hmac"
5         "crypto/rand"
6         "crypto/sha512"
7         "io"
8
9         "github.com/bytom/crypto/ed25519"
10         "github.com/bytom/crypto/ed25519/ecmath"
11         "bytes"
12 )
13
14 type (
15         //XPrv external private key
16         XPrv [64]byte
17         //XPub external public key
18         XPub [64]byte
19 )
20
21 // CompareTwoXPubs
22 func CompareTwoXPubs(a, b []XPub) int {
23         for i, xpub := range a {
24                 result := bytes.Compare(xpub[:], b[i][:])
25                 if result != 0 {
26                         return result
27                 }
28         }
29         return 0
30 }
31
32 // NewXPrv takes a source of random bytes and produces a new XPrv.
33 // If r is nil, crypto/rand.Reader is used.
34 func NewXPrv(r io.Reader) (xprv XPrv, err error) {
35         if r == nil {
36                 r = rand.Reader
37         }
38         var entropy [32]byte
39         _, err = io.ReadFull(r, entropy[:])
40         if err != nil {
41                 return xprv, err
42         }
43         return RootXPrv(entropy[:]), nil
44 }
45
46 // RootXPrv takes a seed binary string and produces a new xprv.
47 func RootXPrv(seed []byte) (xprv XPrv) {
48         h := hmac.New(sha512.New, []byte{'R', 'o', 'o', 't'})
49         h.Write(seed)
50         h.Sum(xprv[:0])
51         pruneRootScalar(xprv[:32])
52         return
53 }
54
55 // XPub derives an extended public key from a given xprv.
56 func (xprv XPrv) XPub() (xpub XPub) {
57         var scalar ecmath.Scalar
58         copy(scalar[:], xprv[:32])
59
60         var P ecmath.Point
61         P.ScMulBase(&scalar)
62         buf := P.Encode()
63
64         copy(xpub[:32], buf[:])
65         copy(xpub[32:], xprv[32:])
66
67         return
68 }
69
70 // Child derives a child xprv based on `selector` string and `hardened` flag.
71 // If `hardened` is false, child xpub can be derived independently
72 // from the parent xpub without using the parent xprv.
73 // If `hardened` is true, child key can only be derived from the parent xprv.
74 func (xprv XPrv) Child(sel []byte, hardened bool) XPrv {
75         if hardened {
76                 return xprv.hardenedChild(sel)
77         }
78         return xprv.nonhardenedChild(sel)
79 }
80
81 func (xprv XPrv) hardenedChild(sel []byte) (res XPrv) {
82         h := hmac.New(sha512.New, xprv[32:])
83         h.Write([]byte{'H'})
84         h.Write(xprv[:32])
85         h.Write(sel)
86         h.Sum(res[:0])
87         pruneRootScalar(res[:32])
88         return
89 }
90
91 func (xprv XPrv) nonhardenedChild(sel []byte) (res XPrv) {
92         xpub := xprv.XPub()
93
94         h := hmac.New(sha512.New, xpub[32:])
95         h.Write([]byte{'N'})
96         h.Write(xpub[:32])
97         h.Write(sel)
98         h.Sum(res[:0])
99
100         pruneIntermediateScalar(res[:32])
101
102         // Unrolled the following loop:
103         // var carry int
104         // carry = 0
105         // for i := 0; i < 32; i++ {
106         //         sum := int(xprv[i]) + int(res[i]) + carry
107         //         res[i] = byte(sum & 0xff)
108         //         carry = (sum >> 8)
109         // }
110
111         sum := int(0)
112
113         sum = int(xprv[0]) + int(res[0]) + (sum >> 8)
114         res[0] = byte(sum & 0xff)
115         sum = int(xprv[1]) + int(res[1]) + (sum >> 8)
116         res[1] = byte(sum & 0xff)
117         sum = int(xprv[2]) + int(res[2]) + (sum >> 8)
118         res[2] = byte(sum & 0xff)
119         sum = int(xprv[3]) + int(res[3]) + (sum >> 8)
120         res[3] = byte(sum & 0xff)
121         sum = int(xprv[4]) + int(res[4]) + (sum >> 8)
122         res[4] = byte(sum & 0xff)
123         sum = int(xprv[5]) + int(res[5]) + (sum >> 8)
124         res[5] = byte(sum & 0xff)
125         sum = int(xprv[6]) + int(res[6]) + (sum >> 8)
126         res[6] = byte(sum & 0xff)
127         sum = int(xprv[7]) + int(res[7]) + (sum >> 8)
128         res[7] = byte(sum & 0xff)
129         sum = int(xprv[8]) + int(res[8]) + (sum >> 8)
130         res[8] = byte(sum & 0xff)
131         sum = int(xprv[9]) + int(res[9]) + (sum >> 8)
132         res[9] = byte(sum & 0xff)
133         sum = int(xprv[10]) + int(res[10]) + (sum >> 8)
134         res[10] = byte(sum & 0xff)
135         sum = int(xprv[11]) + int(res[11]) + (sum >> 8)
136         res[11] = byte(sum & 0xff)
137         sum = int(xprv[12]) + int(res[12]) + (sum >> 8)
138         res[12] = byte(sum & 0xff)
139         sum = int(xprv[13]) + int(res[13]) + (sum >> 8)
140         res[13] = byte(sum & 0xff)
141         sum = int(xprv[14]) + int(res[14]) + (sum >> 8)
142         res[14] = byte(sum & 0xff)
143         sum = int(xprv[15]) + int(res[15]) + (sum >> 8)
144         res[15] = byte(sum & 0xff)
145         sum = int(xprv[16]) + int(res[16]) + (sum >> 8)
146         res[16] = byte(sum & 0xff)
147         sum = int(xprv[17]) + int(res[17]) + (sum >> 8)
148         res[17] = byte(sum & 0xff)
149         sum = int(xprv[18]) + int(res[18]) + (sum >> 8)
150         res[18] = byte(sum & 0xff)
151         sum = int(xprv[19]) + int(res[19]) + (sum >> 8)
152         res[19] = byte(sum & 0xff)
153         sum = int(xprv[20]) + int(res[20]) + (sum >> 8)
154         res[20] = byte(sum & 0xff)
155         sum = int(xprv[21]) + int(res[21]) + (sum >> 8)
156         res[21] = byte(sum & 0xff)
157         sum = int(xprv[22]) + int(res[22]) + (sum >> 8)
158         res[22] = byte(sum & 0xff)
159         sum = int(xprv[23]) + int(res[23]) + (sum >> 8)
160         res[23] = byte(sum & 0xff)
161         sum = int(xprv[24]) + int(res[24]) + (sum >> 8)
162         res[24] = byte(sum & 0xff)
163         sum = int(xprv[25]) + int(res[25]) + (sum >> 8)
164         res[25] = byte(sum & 0xff)
165         sum = int(xprv[26]) + int(res[26]) + (sum >> 8)
166         res[26] = byte(sum & 0xff)
167         sum = int(xprv[27]) + int(res[27]) + (sum >> 8)
168         res[27] = byte(sum & 0xff)
169         sum = int(xprv[28]) + int(res[28]) + (sum >> 8)
170         res[28] = byte(sum & 0xff)
171         sum = int(xprv[29]) + int(res[29]) + (sum >> 8)
172         res[29] = byte(sum & 0xff)
173         sum = int(xprv[30]) + int(res[30]) + (sum >> 8)
174         res[30] = byte(sum & 0xff)
175         sum = int(xprv[31]) + int(res[31]) + (sum >> 8)
176         res[31] = byte(sum & 0xff)
177
178         if (sum >> 8) != 0 {
179                 panic("sum does not fit in 256-bit int")
180         }
181         return
182 }
183
184 // Child derives a child xpub based on `selector` string.
185 // The corresponding child xprv can be derived from the parent xprv
186 // using non-hardened derivation: `parentxprv.Child(sel, false)`.
187 func (xpub XPub) Child(sel []byte) (res XPub) {
188         h := hmac.New(sha512.New, xpub[32:])
189         h.Write([]byte{'N'})
190         h.Write(xpub[:32])
191         h.Write(sel)
192         h.Sum(res[:0])
193
194         pruneIntermediateScalar(res[:32])
195
196         var (
197                 f ecmath.Scalar
198                 F ecmath.Point
199         )
200         copy(f[:], res[:32])
201         F.ScMulBase(&f)
202
203         var (
204                 pubkey [32]byte
205                 P      ecmath.Point
206         )
207         copy(pubkey[:], xpub[:32])
208         _, ok := P.Decode(pubkey)
209         if !ok {
210                 panic("XPub should have been validated on initialization")
211         }
212
213         P.Add(&P, &F)
214         pubkey = P.Encode()
215         copy(res[:32], pubkey[:])
216
217         return
218 }
219
220 // Derive generates a child xprv by recursively deriving
221 // non-hardened child xprvs over the list of selectors:
222 // `Derive([a,b,c,...]) == Child(a).Child(b).Child(c)...`
223 func (xprv XPrv) Derive(path [][]byte) XPrv {
224         res := xprv
225         for _, p := range path {
226                 res = res.Child(p, false)
227         }
228         return res
229 }
230
231 // Derive generates a child xpub by recursively deriving
232 // non-hardened child xpubs over the list of selectors:
233 // `Derive([a,b,c,...]) == Child(a).Child(b).Child(c)...`
234 func (xpub XPub) Derive(path [][]byte) XPub {
235         res := xpub
236         for _, p := range path {
237                 res = res.Child(p)
238         }
239         return res
240 }
241
242 // Sign creates an EdDSA signature using expanded private key
243 // derived from the xprv.
244 func (xprv XPrv) Sign(msg []byte) []byte {
245         return Ed25519InnerSign(xprv.ExpandedPrivateKey(), msg)
246 }
247
248 // Verify checks an EdDSA signature using public key
249 // extracted from the first 32 bytes of the xpub.
250 func (xpub XPub) Verify(msg []byte, sig []byte) bool {
251         return ed25519.Verify(xpub.PublicKey(), msg, sig)
252 }
253
254 // ExpandedPrivateKey generates a 64-byte key where
255 // the first half is the scalar copied from xprv,
256 // and the second half is the `prefix` is generated via PRF
257 // from the xprv.
258 func (xprv XPrv) ExpandedPrivateKey() ExpandedPrivateKey {
259         var res [64]byte
260         h := hmac.New(sha512.New, []byte{'E', 'x', 'p', 'a', 'n', 'd'})
261         h.Write(xprv[:])
262         h.Sum(res[:0])
263         copy(res[:32], xprv[:32])
264         return res[:]
265 }
266
267 // PublicKey extracts the ed25519 public key from an xpub.
268 func (xpub XPub) PublicKey() ed25519.PublicKey {
269         return ed25519.PublicKey(xpub[:32])
270 }
271
272 // s must be >= 32 bytes long and gets rewritten in place.
273 // This is NOT the same pruning as in Ed25519: it additionally clears the third
274 // highest bit to ensure subkeys do not overflow the second highest bit.
275 func pruneRootScalar(s []byte) {
276         s[0] &= 248
277         s[31] &= 31 // clear top 3 bits
278         s[31] |= 64 // set second highest bit
279 }
280
281 // Clears lowest 3 bits and highest 23 bits of `f`.
282 func pruneIntermediateScalar(f []byte) {
283         f[0] &= 248 // clear bottom 3 bits
284         f[29] &= 1  // clear 7 high bits
285         f[30] = 0   // clear 8 bits
286         f[31] = 0   // clear 8 bits
287 }