OSDN Git Service

pass the lha-test10
[lha/olha.git] / huf.c
1 /***********************************************************
2         huf.c -- static Huffman
3 ***********************************************************/
4 #include <stdlib.h>
5 #include "ar.h"
6
7 #define NP (MAXDICBIT + 1)
8 #define NT (CODE_BIT + 3)
9 #define PBIT 4                  /* smallest integer such that (1U << PBIT) > NP */
10 #define TBIT 5                  /* smallest integer such that (1U << TBIT) > NT */
11 #if NT > NP
12 #define NPT NT
13 #else
14 #define NPT NP
15 #endif
16
17 ushort left[2 * NC - 1], right[2 * NC - 1];
18 static uchar *buf, c_len[NC], pt_len[NPT];
19 static uint bufsiz = 0, blocksize;
20 static ushort c_freq[2 * NC - 1], c_table[4096], c_code[NC],
21     p_freq[2 * NP - 1], pt_table[256], pt_code[NPT], t_freq[2 * NT - 1];
22
23 static int np;
24 static int pbit;
25
26 static void
27 init_parameter(struct lha_method *m)
28 {
29     np   = m->dicbit + 1;
30     pbit = m->pbit;
31 }
32
33 /***** encoding *****/
34
35 static void
36 count_t_freq(void)
37 {
38     int i, k, n, count;
39
40     for (i = 0; i < NT; i++)
41         t_freq[i] = 0;
42     n = NC;
43     while (n > 0 && c_len[n - 1] == 0)
44         n--;
45     i = 0;
46     while (i < n) {
47         k = c_len[i++];
48         if (k == 0) {
49             count = 1;
50             while (i < n && c_len[i] == 0) {
51                 i++;
52                 count++;
53             }
54             if (count <= 2)
55                 t_freq[0] += count;
56             else if (count <= 18)
57                 t_freq[1]++;
58             else if (count == 19) {
59                 t_freq[0]++;
60                 t_freq[1]++;
61             }
62             else
63                 t_freq[2]++;
64         }
65         else
66             t_freq[k + 2]++;
67     }
68 }
69
70 static void
71 write_pt_len(int n, int nbit, int i_special)
72 {
73     int i, k;
74
75     while (n > 0 && pt_len[n - 1] == 0)
76         n--;
77     putbits(nbit, n);
78     i = 0;
79     while (i < n) {
80         k = pt_len[i++];
81         if (k <= 6)
82             putbits(3, k);
83         else
84             putbits(k - 3, (1U << (k - 3)) - 2);
85         if (i == i_special) {
86             while (i < 6 && pt_len[i] == 0)
87                 i++;
88             putbits(2, (i - 3) & 3);
89         }
90     }
91 }
92
93 static void
94 write_c_len(void)
95 {
96     int i, k, n, count;
97
98     n = NC;
99     while (n > 0 && c_len[n - 1] == 0)
100         n--;
101     putbits(CBIT, n);
102     i = 0;
103     while (i < n) {
104         k = c_len[i++];
105         if (k == 0) {
106             count = 1;
107             while (i < n && c_len[i] == 0) {
108                 i++;
109                 count++;
110             }
111             if (count <= 2) {
112                 for (k = 0; k < count; k++)
113                     putbits(pt_len[0], pt_code[0]);
114             }
115             else if (count <= 18) {
116                 putbits(pt_len[1], pt_code[1]);
117                 putbits(4, count - 3);
118             }
119             else if (count == 19) {
120                 putbits(pt_len[0], pt_code[0]);
121                 putbits(pt_len[1], pt_code[1]);
122                 putbits(4, 15);
123             }
124             else {
125                 putbits(pt_len[2], pt_code[2]);
126                 putbits(CBIT, count - 20);
127             }
128         }
129         else
130             putbits(pt_len[k + 2], pt_code[k + 2]);
131     }
132 }
133
134 static void
135 encode_c(int c)
136 {
137     putbits(c_len[c], c_code[c]);
138 }
139
140 static void
141 encode_p(uint p)
142 {
143     uint c, q;
144
145     c = 0;
146     q = p;
147     while (q) {
148         q >>= 1;
149         c++;
150     }
151     putbits(pt_len[c], pt_code[c]);
152     if (c > 1)
153         putbits(c - 1, p & (0xFFFFU >> (17 - c)));
154 }
155
156 static void
157 send_block(void)
158 {
159     uint i, k, flags, root, pos, size;
160
161     root = make_tree(NC, c_freq, c_len, c_code);
162     size = c_freq[root];
163     putbits(16, size);
164     if (root >= NC) {
165         count_t_freq();
166         root = make_tree(NT, t_freq, pt_len, pt_code);
167         if (root >= NT) {
168             write_pt_len(NT, TBIT, 3);
169         }
170         else {
171             putbits(TBIT, 0);
172             putbits(TBIT, root);
173         }
174         write_c_len();
175     }
176     else {
177         putbits(TBIT, 0);
178         putbits(TBIT, 0);
179         putbits(CBIT, 0);
180         putbits(CBIT, root);
181     }
182     root = make_tree(np, p_freq, pt_len, pt_code);
183     if (root >= np) {
184         write_pt_len(np, pbit, -1);
185     }
186     else {
187         putbits(pbit, 0);
188         putbits(pbit, root);
189     }
190     pos = 0;
191     for (i = 0; i < size; i++) {
192         if (i % CHAR_BIT == 0)
193             flags = buf[pos++];
194         else
195             flags <<= 1;
196         if (flags & (1U << (CHAR_BIT - 1))) {
197             encode_c(buf[pos++] + (1U << CHAR_BIT));
198             k = buf[pos++] << CHAR_BIT;
199             k += buf[pos++];
200             encode_p(k);
201         }
202         else
203             encode_c(buf[pos++]);
204         if (unpackable)
205             return;
206     }
207     for (i = 0; i < NC; i++)
208         c_freq[i] = 0;
209     for (i = 0; i < np; i++)
210         p_freq[i] = 0;
211 }
212
213 static uint output_pos, output_mask;
214
215 void
216 output(uint c, uint p)
217 {
218     static uint cpos;
219
220     if ((output_mask >>= 1) == 0) {
221         output_mask = 1U << (CHAR_BIT - 1);
222         if (output_pos >= bufsiz - 3 * CHAR_BIT) {
223             send_block();
224             if (unpackable)
225                 return;
226             output_pos = 0;
227         }
228         cpos = output_pos++;
229         buf[cpos] = 0;
230     }
231     buf[output_pos++] = (uchar) c;
232     c_freq[c]++;
233     if (c >= (1U << CHAR_BIT)) {
234         buf[cpos] |= output_mask;
235         buf[output_pos++] = (uchar) (p >> CHAR_BIT);
236         buf[output_pos++] = (uchar) p;
237         c = 0;
238         while (p) {
239             p >>= 1;
240             c++;
241         }
242         p_freq[c]++;
243     }
244 }
245
246 void
247 huf_encode_start(struct lha_method *m)
248 {
249     int i;
250
251     init_parameter(m);
252
253     if (bufsiz == 0) {
254         bufsiz = 16 * 1024U;
255         while ((buf = malloc(bufsiz)) == NULL) {
256             bufsiz = (bufsiz / 10U) * 9U;
257             if (bufsiz < 4 * 1024U)
258                 error("Out of memory.");
259         }
260     }
261     buf[0] = 0;
262     for (i = 0; i < NC; i++)
263         c_freq[i] = 0;
264     for (i = 0; i < np; i++)
265         p_freq[i] = 0;
266     output_pos = output_mask = 0;
267     init_putbits();
268 }
269
270 void
271 huf_encode_end(void)
272 {
273     if (!unpackable) {
274         send_block();
275         putbits(CHAR_BIT - 1, 0);       /* flush remaining bits */
276     }
277 }
278
279 /***** decoding *****/
280
281 static void
282 read_pt_len(int nn, int nbit, int i_special)
283 {
284     int i, c, n;
285     uint mask;
286
287     n = getbits(nbit);
288     if (n == 0) {
289         c = getbits(nbit);
290         for (i = 0; i < nn; i++)
291             pt_len[i] = 0;
292         for (i = 0; i < 256; i++)
293             pt_table[i] = c;
294     }
295     else {
296         i = 0;
297         while (i < n) {
298             c = bitbuf >> (BITBUFSIZ - 3);
299             if (c == 7) {
300                 mask = 1U << (BITBUFSIZ - 1 - 3);
301                 while (mask & bitbuf) {
302                     mask >>= 1;
303                     c++;
304                 }
305             }
306             fillbuf((c < 7) ? 3 : c - 3);
307             pt_len[i++] = c;
308             if (i == i_special) {
309                 c = getbits(2);
310                 while (--c >= 0)
311                     pt_len[i++] = 0;
312             }
313         }
314         while (i < nn)
315             pt_len[i++] = 0;
316         make_table(nn, pt_len, 8, pt_table);
317     }
318 }
319
320 static void
321 read_c_len(void)
322 {
323     int i, c, n;
324     uint mask;
325
326     n = getbits(CBIT);
327     if (n == 0) {
328         c = getbits(CBIT);
329         for (i = 0; i < NC; i++)
330             c_len[i] = 0;
331         for (i = 0; i < 4096; i++)
332             c_table[i] = c;
333     }
334     else {
335         i = 0;
336         while (i < n) {
337             c = pt_table[bitbuf >> (BITBUFSIZ - 8)];
338             if (c >= NT) {
339                 mask = 1U << (BITBUFSIZ - 1 - 8);
340                 do {
341                     if (bitbuf & mask)
342                         c = right[c];
343                     else
344                         c = left[c];
345                     mask >>= 1;
346                 } while (c >= NT);
347             }
348             fillbuf(pt_len[c]);
349             if (c <= 2) {
350                 if (c == 0)
351                     c = 1;
352                 else if (c == 1)
353                     c = getbits(4) + 3;
354                 else
355                     c = getbits(CBIT) + 20;
356                 while (--c >= 0)
357                     c_len[i++] = 0;
358             }
359             else
360                 c_len[i++] = c - 2;
361         }
362         while (i < NC)
363             c_len[i++] = 0;
364         make_table(NC, c_len, 12, c_table);
365     }
366 }
367
368 uint
369 decode_c(void)
370 {
371     uint j, mask;
372
373     if (blocksize == 0) {
374         blocksize = getbits(16);
375         read_pt_len(NT, TBIT, 3);
376         read_c_len();
377         read_pt_len(np, pbit, -1);
378     }
379     blocksize--;
380     j = c_table[bitbuf >> (BITBUFSIZ - 12)];
381     if (j >= NC) {
382         mask = 1U << (BITBUFSIZ - 1 - 12);
383         do {
384             if (bitbuf & mask)
385                 j = right[j];
386             else
387                 j = left[j];
388             mask >>= 1;
389         } while (j >= NC);
390     }
391     fillbuf(c_len[j]);
392     return j;
393 }
394
395 uint
396 decode_p(void)
397 {
398     uint j, mask;
399
400     j = pt_table[bitbuf >> (BITBUFSIZ - 8)];
401     if (j >= np) {
402         mask = 1U << (BITBUFSIZ - 1 - 8);
403         do {
404             if (bitbuf & mask)
405                 j = right[j];
406             else
407                 j = left[j];
408             mask >>= 1;
409         } while (j >= np);
410     }
411     fillbuf(pt_len[j]);
412     if (j != 0)
413         j = (1U << (j - 1)) + getbits(j - 1);
414     return j;
415 }
416
417 void
418 huf_decode_start(struct lha_method *m)
419 {
420     init_parameter(m);
421     init_getbits();
422     blocksize = 0;
423 }