1 /*
2  * WavPack lossless audio decoder
3  * Copyright (c) 2006,2011 Konstantin Shishkov
4  * Copyright (c) 2020 David Bryant
5  *
6  * This file is part of FFmpeg.
7  *
8  * FFmpeg is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU Lesser General Public
10  * License as published by the Free Software Foundation; either
11  * version 2.1 of the License, or (at your option) any later version.
12  *
13  * FFmpeg is distributed in the hope that it will be useful,
14  * but WITHOUT ANY WARRANTY; without even the implied warranty of
15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
16  * Lesser General Public License for more details.
17  *
18  * You should have received a copy of the GNU Lesser General Public
19  * License along with FFmpeg; if not, write to the Free Software
20  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
21  */
22 
23 #include "libavutil/buffer.h"
24 #include "libavutil/channel_layout.h"
25 
26 #define BITSTREAM_READER_LE
27 #include "avcodec.h"
28 #include "bytestream.h"
29 #include "codec_internal.h"
30 #include "get_bits.h"
31 #include "thread.h"
32 #include "threadframe.h"
33 #include "unary.h"
34 #include "wavpack.h"
35 #include "dsd.h"
36 
37 /**
38  * @file
39  * WavPack lossless audio decoder
40  */
41 
/* True when the most significant bytes of the two arguments are identical,
 * i.e. (low ^ high) has no bit set in 0xff000000. The DSD decoders test this
 * before shifting another input byte into their low/high/value state. The
 * expression is symmetric, so argument order does not matter. */
#define DSD_BYTE_READY(low,high) (!(((low) ^ (high)) & 0xff000000))

/* Number of entries in the probability table used by wv_unpack_dsd_high()
 * (PTABLE_BINS) and the mask used to index it (PTABLE_MASK). */
#define PTABLE_BITS 8
#define PTABLE_BINS (1<<PTABLE_BITS)
#define PTABLE_MASK (PTABLE_BINS-1)

/* Adaptation of a probability table entry p is
 *     p += (UP - p) >> DECAY     or     p += (DOWN - p) >> DECAY
 * depending on which way the decoded bit went. */
#define UP   0x010000fe
#define DOWN 0x00010000
#define DECAY 8

/* Fixed-point layout of the DSD "high" filters: initial filter bytes are
 * shifted left by (PRECISION - 8), VALUE_ONE is the level a set bit drives
 * the filters towards, and the predictor value is shifted right by
 * (PRECISION - PRECISION_USE) before being masked into a table index. */
#define PRECISION 20
#define VALUE_ONE (1 << PRECISION)
#define PRECISION_USE 12

/* The only rate_s value wv_unpack_dsd_high() accepts from the bitstream. */
#define RATE_S 20

/* Limits for wv_unpack_dsd_fast(): at most MAX_HISTORY_BINS history bins,
 * with a lookup budget of MAX_BIN_BYTES bytes per bin. */
#define MAX_HISTORY_BITS    5
#define MAX_HISTORY_BINS    (1 << MAX_HISTORY_BITS)
#define MAX_BIN_BYTES       1280    // for value_lookup, per bin (2k - 512 - 256)
61 
/* Kind of audio modulation; stored in WavpackContext.modulation. */
typedef enum {
    MODULATION_PCM,     // pulse code modulation
    MODULATION_DSD      // pulse density modulation (aka DSD)
} Modulation;
66 
/* State used while unpacking a single WavPack block. */
typedef struct WavpackFrameContext {
    AVCodecContext *avctx;          // used for logging and err_recognition
    int frame_flags;
    int stereo, stereo_in;          // stereo_in: index of the last channel in ch[] to update
    int joint;                      // nonzero: undo joint-stereo coding in wv_unpack_stereo()
    uint32_t CRC;                   // expected checksum, compared in wv_check_crc()
    GetBitContext gb;
    int got_extra_bits;             // nonzero: gb_extra_bits may be read from
    uint32_t crc_extra_bits;        // expected checksum of the extra-bits data
    GetBitContext gb_extra_bits;
    int samples;                    // number of samples (per channel) to decode
    int terms;                      // number of decorr[] entries applied per sample
    Decorr decorr[MAX_TERMS];
    int zero, one, zeroes;          // entropy decoder state, see wv_get_value()
    int extra_bits;                 // bits appended below each integer sample
    int and, or, shift;             // integer reconstruction, see wv_get_value_integer()
    int post_shift;                 // final left shift of integer samples
    int hybrid, hybrid_bitrate;
    int hybrid_maxclip, hybrid_minclip; // clipping range applied when hybrid is set
    int float_flag;                 // WV_FLT_* bits, see wv_get_value_float()
    int float_shift;
    int float_max_exp;
    WvChannel ch[2];                // per-channel entropy coder state

    /* DSD decoding state */
    GetByteContext gbyte;           // byte reader used by the wv_unpack_dsd_*() functions
    int ptable [PTABLE_BINS];       // adaptive probabilities, wv_unpack_dsd_high()
    /* The remaining tables are used by wv_unpack_dsd_fast(): per history bin,
     * the raw probabilities, their running sums, and a pointer into
     * value_lookup_buffer mapping a cumulative index back to a byte value. */
    uint8_t value_lookup_buffer[MAX_HISTORY_BINS*MAX_BIN_BYTES];
    uint16_t summed_probabilities[MAX_HISTORY_BINS][256];
    uint8_t probabilities[MAX_HISTORY_BINS][256];
    uint8_t *value_lookup[MAX_HISTORY_BINS];
} WavpackFrameContext;
98 
/* Capacity of WavpackContext.fdec[]. */
#define WV_MAX_FRAME_DECODERS 14

/* Decoder-wide state.
 * NOTE(review): none of these fields are touched in the part of the file
 * visible here; the per-field notes below are inferred from names and types
 * only and should be confirmed against the code that uses them. */
typedef struct WavpackContext {
    AVCodecContext *avctx;

    WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
    int fdec_num;                   // presumably the number of valid fdec[] entries

    int block;
    int samples;
    int ch_offset;

    AVFrame *frame;
    ThreadFrame curr_frame, prev_frame;
    Modulation modulation;          // MODULATION_PCM or MODULATION_DSD

    AVBufferRef *dsd_ref;
    DSDContext *dsdctx;
    int dsd_channels;
} WavpackContext;
119 
/* (a) / 256 with rounding, via add-half-then-shift; used to decay slow_level. */
#define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
121 
get_tail(GetBitContext *gb, int k)122 static av_always_inline unsigned get_tail(GetBitContext *gb, int k)
123 {
124     int p, e, res;
125 
126     if (k < 1)
127         return 0;
128     p   = av_log2(k);
129     e   = (1 << (p + 1)) - k - 1;
130     res = get_bitsz(gb, p);
131     if (res >= e)
132         res = res * 2U - e + get_bits1(gb);
133     return res;
134 }
135 
/**
 * Advance the per-channel bitrate accumulators and derive each channel's
 * error_limit from them.
 *
 * @param ctx block context; ch[0..stereo_in] are updated
 * @return 0 on success, AVERROR_INVALIDDATA if an accumulator would overflow
 */
static int update_error_limit(WavpackFrameContext *ctx)
{
    int rate[2], level[2];
    const int last_chan = ctx->stereo_in;

    for (int chan = 0; chan <= last_chan; chan++) {
        WvChannel *wc = &ctx->ch[chan];

        /* refuse to let the unsigned accumulator wrap around */
        if (wc->bitrate_acc > UINT_MAX - wc->bitrate_delta)
            return AVERROR_INVALIDDATA;

        wc->bitrate_acc += wc->bitrate_delta;
        rate[chan]       = wc->bitrate_acc >> 16;
        level[chan]      = LEVEL_DECAY(wc->slow_level);
    }

    /* redistribute the rate between both channels according to their levels */
    if (last_chan && ctx->hybrid_bitrate) {
        const int balance = (level[1] - level[0] + rate[1] + 1) >> 1;

        if (balance > rate[0]) {
            rate[1] = rate[0] * 2;
            rate[0] = 0;
        } else if (-balance > rate[0]) {
            rate[0] *= 2;
            rate[1]  = 0;
        } else {
            rate[1]  = rate[0] + balance;
            rate[0] -= balance;
        }
    }

    for (int chan = 0; chan <= last_chan; chan++) {
        WvChannel *wc = &ctx->ch[chan];

        if (!ctx->hybrid_bitrate)
            wc->error_limit = wp_exp2(rate[chan]);
        else if (level[chan] - rate[chan] > -0x100)
            wc->error_limit = wp_exp2(level[chan] - rate[chan] + 0x100);
        else
            wc->error_limit = 0;
    }

    return 0;
}
173 
/**
 * Decode one entropy-coded value for the given channel.
 *
 * @param ctx     block context; updates zero/one/zeroes and the channel state
 * @param gb      bit reader to consume from
 * @param channel index into ctx->ch[]
 * @param last    set to 0 on success and to 1 when decoding failed
 * @return the decoded value; 0 when *last is set to 1
 */
static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
                        int channel, int *last)
{
    int t, t2;
    int sign, base, add, ret;
    /* NOTE(review): GET_MED/INC_MED/DEC_MED are defined outside this file
     * section; they appear to operate on this local `c` - confirm before
     * renaming it. */
    WvChannel *c = &ctx->ch[channel];

    *last = 0;

    /* Runs of zeros are only considered while both channels have a tiny
     * first median and no zero/one flag is pending. */
    if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
        !ctx->zero && !ctx->one) {
        if (ctx->zeroes) {
            /* inside a run: emit another zero unless the run just ended */
            ctx->zeroes--;
            if (ctx->zeroes) {
                c->slow_level -= LEVEL_DECAY(c->slow_level);
                return 0;
            }
        } else {
            /* read a new run length: unary prefix, then for prefixes >= 2
             * t - 1 explicit bits below an implied top bit */
            t = get_unary_0_33(gb);
            if (t >= 2) {
                if (t >= 32 || get_bits_left(gb) < t - 1)
                    goto error;
                t = get_bits_long(gb, t - 1) | (1 << (t - 1));
            } else {
                if (get_bits_left(gb) < 0)
                    goto error;
            }
            ctx->zeroes = t;
            if (ctx->zeroes) {
                /* starting a run resets the medians of both channels */
                memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
                memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
                c->slow_level -= LEVEL_DECAY(c->slow_level);
                return 0;
            }
        }
    }

    if (ctx->zero) {
        /* a pending zero flag stands for t == 0 without reading any bits */
        t         = 0;
        ctx->zero = 0;
    } else {
        t = get_unary_0_33(gb);
        if (get_bits_left(gb) < 0)
            goto error;
        if (t == 16) {
            /* escape: a second code extends t beyond 16 */
            t2 = get_unary_0_33(gb);
            if (t2 < 2) {
                if (get_bits_left(gb) < 0)
                    goto error;
                t += t2;
            } else {
                if (t2 >= 32 || get_bits_left(gb) < t2 - 1)
                    goto error;
                t += get_bits_long(gb, t2 - 1) | (1 << (t2 - 1));
            }
        }

        /* the low bit of t becomes the next `one` flag; the remaining bits,
         * plus one if the previous `one` flag was set, form the new t */
        if (ctx->one) {
            ctx->one = t & 1;
            t        = (t >> 1) + 1;
        } else {
            ctx->one = t & 1;
            t      >>= 1;
        }
        ctx->zero = !ctx->one;
    }

    /* in hybrid mode the error limits are refreshed once per sample,
     * when the first channel is decoded */
    if (ctx->hybrid && !channel) {
        if (update_error_limit(ctx) < 0)
            goto error;
    }

    /* t selects the interval [base, base + add] built from the medians,
     * and the medians are adapted accordingly */
    if (!t) {
        base = 0;
        add  = GET_MED(0) - 1;
        DEC_MED(0);
    } else if (t == 1) {
        base = GET_MED(0);
        add  = GET_MED(1) - 1;
        INC_MED(0);
        DEC_MED(1);
    } else if (t == 2) {
        base = GET_MED(0) + GET_MED(1);
        add  = GET_MED(2) - 1;
        INC_MED(0);
        INC_MED(1);
        DEC_MED(2);
    } else {
        base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2U);
        add  = GET_MED(2) - 1;
        INC_MED(0);
        INC_MED(1);
        INC_MED(2);
    }
    if (!c->error_limit) {
        /* no error limit: read the exact position inside the interval */
        if (add >= 0x2000000U) {
            av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
            goto error;
        }
        ret = base + get_tail(gb, add);
        if (get_bits_left(gb) <= 0)
            goto error;
    } else {
        /* error limit set: halve the interval one bit at a time until it is
         * no wider than error_limit, then take its midpoint */
        int mid = (base * 2U + add + 1) >> 1;
        while (add > c->error_limit) {
            if (get_bits_left(gb) <= 0)
                goto error;
            if (get_bits1(gb)) {
                add -= (mid - (unsigned)base);
                base = mid;
            } else
                add = mid - (unsigned)base - 1;
            mid = (base * 2U + add + 1) >> 1;
        }
        ret = mid;
    }
    sign = get_bits1(gb);
    if (ctx->hybrid_bitrate)
        c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
    /* a set sign bit yields the one's complement of the magnitude */
    return sign ? ~ret : ret;

error:
    /* the message is only logged when the reader really ran out of bits */
    ret = get_bits_left(gb);
    if (ret <= 0) {
        av_log(ctx->avctx, AV_LOG_ERROR, "Too few bits (%d) left\n", ret);
    }
    *last = 1;
    return 0;
}
303 
/**
 * Turn a decoded sample into its final integer output value.
 *
 * @param s   block context supplying extra_bits, and/or/shift, the hybrid
 *            clipping range and post_shift
 * @param crc running checksum of the extra-bits data; only updated when
 *            extra bits were actually read
 * @param S   decoded sample
 * @return the reconstructed sample
 */
static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
                                       unsigned S)
{
    unsigned sample;

    if (s->extra_bits) {
        /* make room for the low-order bits stored in the extra stream */
        S *= 1 << s->extra_bits;

        if (s->got_extra_bits) {
            if (get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
                S   |= get_bits_long(&s->gb_extra_bits, s->extra_bits);
                *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
            }
        }
    }

    sample = (S & s->and) | s->or;
    sample = ((S + sample) << s->shift) - sample;

    if (s->hybrid)
        sample = av_clip(sample, s->hybrid_minclip, s->hybrid_maxclip);

    return sample << s->post_shift;
}
327 
/**
 * Turn a decoded integer sample into a float by assembling the bit pattern
 * (sign << 31) | (exp << 23) | mantissa, pulling additional bits from the
 * extra-bits stream where the float flags ask for them.
 *
 * @param s   block context (float_flag, float_shift, float_max_exp,
 *            got_extra_bits, gb_extra_bits)
 * @param crc running checksum, updated from the mantissa, exponent and sign
 * @param S   decoded sample
 * @return the reconstructed float; 0.0 without touching *crc when the
 *         extra-bits reader is too far past its end
 */
static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
{
    union {
        float    f;
        uint32_t u;
    } value;

    unsigned int sign;
    int exp = s->float_max_exp;

    if (s->got_extra_bits) {
        /* largest number of extra bits one sample can consume below:
         * flag + 23 mantissa bits + 8 exponent bits + sign */
        const int max_bits  = 1 + 23 + 8 + 1;
        const int left_bits = get_bits_left(&s->gb_extra_bits);

        /* reads may run into the input padding, but not beyond it */
        if (left_bits + 8 * AV_INPUT_BUFFER_PADDING_SIZE < max_bits)
            return 0.0;
    }

    if (S) {
        S  *= 1U << s->float_shift;
        sign = S < 0;
        if (sign)
            S = -(unsigned)S;   /* negate in unsigned arithmetic */
        if (S >= 0x1000000U) {
            /* magnitude does not fit in 24 bits: exponent 255, mantissa
             * taken from the extra bits if a flag bit says so, else zero */
            if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
                S = get_bits(&s->gb_extra_bits, 23);
            else
                S = 0;
            exp = 255;
        } else if (exp) {
            /* shift the top set bit up to bit 23, limited by the exponent */
            int shift = 23 - av_log2(S);
            exp = s->float_max_exp;
            if (exp <= shift)
                shift = --exp;
            exp -= shift;

            if (shift) {
                S <<= shift;
                /* fill the vacated low bits: all ones, or bits read from
                 * the extra stream, depending on the float flags; note that
                 * the condition itself may consume one bit */
                if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
                    (s->got_extra_bits &&
                     (s->float_flag & WV_FLT_SHIFT_SAME) &&
                     get_bits1(&s->gb_extra_bits))) {
                    S |= (1 << shift) - 1;
                } else if (s->got_extra_bits &&
                           (s->float_flag & WV_FLT_SHIFT_SENT)) {
                    S |= get_bits(&s->gb_extra_bits, shift);
                }
            }
        } else {
            exp = s->float_max_exp;
        }
        /* keep the 23 mantissa bits only */
        S &= 0x7fffff;
    } else {
        sign = 0;
        exp  = 0;
        /* a zero sample may still carry explicitly sent mantissa, exponent
         * and sign bits, or just a sign */
        if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
            if (get_bits1(&s->gb_extra_bits)) {
                S = get_bits(&s->gb_extra_bits, 23);
                if (s->float_max_exp >= 25)
                    exp = get_bits(&s->gb_extra_bits, 8);
                sign = get_bits1(&s->gb_extra_bits);
            } else {
                if (s->float_flag & WV_FLT_ZERO_SIGN)
                    sign = get_bits1(&s->gb_extra_bits);
            }
        }
    }

    *crc = *crc * 27 + S * 9 + exp * 3 + sign;

    value.u = (sign << 31) | (exp << 23) | S;
    return value.f;
}
401 
/**
 * Compare computed checksums against the ones stored in the block context.
 *
 * @param s              block context holding the expected values
 * @param crc            checksum computed over the decoded samples
 * @param crc_extra_bits checksum computed over the extra bits; only compared
 *                       when s->got_extra_bits is set
 * @return 0 if everything matches, AVERROR_INVALIDDATA (after logging which
 *         checksum failed) otherwise
 */
static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
                               uint32_t crc_extra_bits)
{
    if (s->CRC != crc) {
        av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
        return AVERROR_INVALIDDATA;
    }

    /* nothing else to verify without extra bits, or when they match too */
    if (!s->got_extra_bits || s->crc_extra_bits == crc_extra_bits)
        return 0;

    av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
    return AVERROR_INVALIDDATA;
}
416 
init_ptable(int *table, int rate_i, int rate_s)417 static void init_ptable(int *table, int rate_i, int rate_s)
418 {
419     int value = 0x808000, rate = rate_i << 8;
420 
421     for (int c = (rate + 128) >> 8; c--;)
422         value += (DOWN - value) >> DECAY;
423 
424     for (int i = 0; i < PTABLE_BINS/2; i++) {
425         table[i] = value;
426         table[PTABLE_BINS-1-i] = 0x100ffff - value;
427 
428         if (value > 0x010000) {
429             rate += (rate * rate_s + 128) >> 8;
430 
431             for (int c = (rate + 64) >> 7; c--;)
432                 value += (DOWN - value) >> DECAY;
433         }
434     }
435 }
436 
/* Per-channel filter state of wv_unpack_dsd_high(). */
typedef struct {
    /* value:  predictor output, used to pick the probability table entry
     * fltr0:  last decoded bit, stored as -1 (set) or 0 (clear)
     * fltr1-fltr6: filter stages updated after every decoded bit
     * factor: adaptive weight applied to fltr6 in the predictor */
    int32_t value, fltr0, fltr1, fltr2, fltr3, fltr4, fltr5, fltr6, factor;
    unsigned int byte;  // decoded bits shifted in LSB-first; the low 8 bits form the output byte
} DSDfilters;
441 
wv_unpack_dsd_high(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)442 static int wv_unpack_dsd_high(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
443 {
444     uint32_t checksum = 0xFFFFFFFF;
445     uint8_t *dst_l = dst_left, *dst_r = dst_right;
446     int total_samples = s->samples, stereo = dst_r ? 1 : 0;
447     DSDfilters filters[2], *sp = filters;
448     int rate_i, rate_s;
449     uint32_t low, high, value;
450 
451     if (bytestream2_get_bytes_left(&s->gbyte) < (stereo ? 20 : 13))
452         return AVERROR_INVALIDDATA;
453 
454     rate_i = bytestream2_get_byte(&s->gbyte);
455     rate_s = bytestream2_get_byte(&s->gbyte);
456 
457     if (rate_s != RATE_S)
458         return AVERROR_INVALIDDATA;
459 
460     init_ptable(s->ptable, rate_i, rate_s);
461 
462     for (int channel = 0; channel < stereo + 1; channel++) {
463         DSDfilters *sp = filters + channel;
464 
465         sp->fltr1 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
466         sp->fltr2 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
467         sp->fltr3 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
468         sp->fltr4 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
469         sp->fltr5 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
470         sp->fltr6 = 0;
471         sp->factor = bytestream2_get_byte(&s->gbyte) & 0xff;
472         sp->factor |= (bytestream2_get_byte(&s->gbyte) << 8) & 0xff00;
473         sp->factor = (int32_t)((uint32_t)sp->factor << 16) >> 16;
474     }
475 
476     value = bytestream2_get_be32(&s->gbyte);
477     high = 0xffffffff;
478     low = 0x0;
479 
480     while (total_samples--) {
481         int bitcount = 8;
482 
483         sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
484 
485         if (stereo)
486             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
487 
488         while (bitcount--) {
489             int32_t *pp = s->ptable + ((sp[0].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
490             uint32_t split = low + ((high - low) >> 8) * (*pp >> 16);
491 
492             if (value <= split) {
493                 high = split;
494                 *pp += (UP - *pp) >> DECAY;
495                 sp[0].fltr0 = -1;
496             } else {
497                 low = split + 1;
498                 *pp += (DOWN - *pp) >> DECAY;
499                 sp[0].fltr0 = 0;
500             }
501 
502             if (DSD_BYTE_READY(high, low) && !bytestream2_get_bytes_left(&s->gbyte))
503                 return AVERROR_INVALIDDATA;
504             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
505                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
506                 high = (high << 8) | 0xff;
507                 low <<= 8;
508             }
509 
510             sp[0].value += sp[0].fltr6 * 8;
511             sp[0].byte = (sp[0].byte << 1) | (sp[0].fltr0 & 1);
512             sp[0].factor += (((sp[0].value ^ sp[0].fltr0) >> 31) | 1) &
513                 ((sp[0].value ^ (sp[0].value - (sp[0].fltr6 * 16))) >> 31);
514             sp[0].fltr1 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr1) >> 6;
515             sp[0].fltr2 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr2) >> 4;
516             sp[0].fltr3 += (sp[0].fltr2 - sp[0].fltr3) >> 4;
517             sp[0].fltr4 += (sp[0].fltr3 - sp[0].fltr4) >> 4;
518             sp[0].value = (sp[0].fltr4 - sp[0].fltr5) >> 4;
519             sp[0].fltr5 += sp[0].value;
520             sp[0].fltr6 += (sp[0].value - sp[0].fltr6) >> 3;
521             sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
522 
523             if (!stereo)
524                 continue;
525 
526             pp = s->ptable + ((sp[1].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
527             split = low + ((high - low) >> 8) * (*pp >> 16);
528 
529             if (value <= split) {
530                 high = split;
531                 *pp += (UP - *pp) >> DECAY;
532                 sp[1].fltr0 = -1;
533             } else {
534                 low = split + 1;
535                 *pp += (DOWN - *pp) >> DECAY;
536                 sp[1].fltr0 = 0;
537             }
538 
539             if (DSD_BYTE_READY(high, low) && !bytestream2_get_bytes_left(&s->gbyte))
540                 return AVERROR_INVALIDDATA;
541             while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
542                 value = (value << 8) | bytestream2_get_byte(&s->gbyte);
543                 high = (high << 8) | 0xff;
544                 low <<= 8;
545             }
546 
547             sp[1].value += sp[1].fltr6 * 8;
548             sp[1].byte = (sp[1].byte << 1) | (sp[1].fltr0 & 1);
549             sp[1].factor += (((sp[1].value ^ sp[1].fltr0) >> 31) | 1) &
550                 ((sp[1].value ^ (sp[1].value - (sp[1].fltr6 * 16))) >> 31);
551             sp[1].fltr1 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr1) >> 6;
552             sp[1].fltr2 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr2) >> 4;
553             sp[1].fltr3 += (sp[1].fltr2 - sp[1].fltr3) >> 4;
554             sp[1].fltr4 += (sp[1].fltr3 - sp[1].fltr4) >> 4;
555             sp[1].value = (sp[1].fltr4 - sp[1].fltr5) >> 4;
556             sp[1].fltr5 += sp[1].value;
557             sp[1].fltr6 += (sp[1].value - sp[1].fltr6) >> 3;
558             sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
559         }
560 
561         checksum += (checksum << 1) + (*dst_l = sp[0].byte & 0xff);
562         sp[0].factor -= (sp[0].factor + 512) >> 10;
563         dst_l += 4;
564 
565         if (stereo) {
566             checksum += (checksum << 1) + (*dst_r = filters[1].byte & 0xff);
567             filters[1].factor -= (filters[1].factor + 512) >> 10;
568             dst_r += 4;
569         }
570     }
571 
572     if (wv_check_crc(s, checksum, 0)) {
573         if (s->avctx->err_recognition & AV_EF_CRCCHECK)
574             return AVERROR_INVALIDDATA;
575 
576         memset(dst_left, 0x69, s->samples * 4);
577 
578         if (dst_r)
579             memset(dst_right, 0x69, s->samples * 4);
580     }
581 
582     return 0;
583 }
584 
/**
 * Unpack a DSD block coded in "fast" mode: a range decoder driven by
 * probability tables that are transmitted at the start of the block.
 *
 * Each sample is one byte per channel, written to the destination with a
 * stride of 4 bytes; with two channels the decoded bytes alternate between
 * left and right.
 *
 * @param s         block context; s->gbyte supplies the data, s->samples
 *                  the number of bytes to produce per channel
 * @param dst_left  output of the first channel
 * @param dst_right output of the second channel, NULL for mono
 * @return 0 on success, AVERROR_INVALIDDATA on truncated or invalid data,
 *         or on a checksum mismatch when AV_EF_CRCCHECK is set; without
 *         that flag a mismatch fills the output with 0x69 and returns 0
 */
static int wv_unpack_dsd_fast(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
{
    uint8_t *dst_l = dst_left, *dst_r = dst_right;
    uint8_t history_bits, max_probability;
    int total_summed_probabilities  = 0;
    int total_samples               = s->samples;
    uint8_t *vlb                    = s->value_lookup_buffer;
    int history_bins, p0, p1, chan;
    uint32_t checksum               = 0xFFFFFFFF;
    uint32_t low, high, value;

    if (!bytestream2_get_bytes_left(&s->gbyte))
        return AVERROR_INVALIDDATA;

    history_bits = bytestream2_get_byte(&s->gbyte);

    if (!bytestream2_get_bytes_left(&s->gbyte) || history_bits > MAX_HISTORY_BITS)
        return AVERROR_INVALIDDATA;

    history_bins = 1 << history_bits;
    max_probability = bytestream2_get_byte(&s->gbyte);

    if (max_probability < 0xff) {
        /* run-length coded tables: a code above max_probability stands for
         * (code - max_probability) zero entries, a nonzero code up to
         * max_probability is a literal entry, and 0 ends the data */
        uint8_t *outptr = (uint8_t *)s->probabilities;
        uint8_t *outend = outptr + sizeof(*s->probabilities) * history_bins;

        while (outptr < outend && bytestream2_get_bytes_left(&s->gbyte)) {
            int code = bytestream2_get_byte(&s->gbyte);

            if (code > max_probability) {
                int zcount = code - max_probability;

                while (outptr < outend && zcount--)
                    *outptr++ = 0;
            } else if (code) {
                *outptr++ = code;
            }
            else {
                break;
            }
        }

        /* every used bin must have been filled, and the byte following
         * the table data (if there is one) must be zero */
        if (outptr < outend ||
            (bytestream2_get_bytes_left(&s->gbyte) && bytestream2_get_byte(&s->gbyte)))
                return AVERROR_INVALIDDATA;
    } else if (bytestream2_get_bytes_left(&s->gbyte) > (int)sizeof(*s->probabilities) * history_bins) {
        /* tables stored verbatim */
        bytestream2_get_buffer(&s->gbyte, (uint8_t *)s->probabilities,
            sizeof(*s->probabilities) * history_bins);
    } else {
        return AVERROR_INVALIDDATA;
    }

    /* build, per bin, the running sums of the probabilities and a lookup
     * table that maps a cumulative index back to the byte value */
    for (p0 = 0; p0 < history_bins; p0++) {
        int32_t sum_values = 0;

        for (int i = 0; i < 256; i++)
            s->summed_probabilities[p0][i] = sum_values += s->probabilities[p0][i];

        if (sum_values) {
            total_summed_probabilities += sum_values;

            /* checked before writing so that vlb cannot run past the end
             * of value_lookup_buffer */
            if (total_summed_probabilities > history_bins * MAX_BIN_BYTES)
                return AVERROR_INVALIDDATA;

            s->value_lookup[p0] = vlb;

            for (int i = 0; i < 256; i++) {
                int c = s->probabilities[p0][i];

                while (c--)
                    *vlb++ = i;
            }
        }
    }

    if (bytestream2_get_bytes_left(&s->gbyte) < 4)
        return AVERROR_INVALIDDATA;

    chan = p0 = p1 = 0;
    low = 0; high = 0xffffffff;
    value = bytestream2_get_be32(&s->gbyte);

    /* from here on total_samples counts bytes across both channels */
    if (dst_r)
        total_samples *= 2;

    while (total_samples--) {
        unsigned int mult, index, code;

        /* a bin without any probability mass cannot code anything; this
         * also covers bins whose value_lookup[] entry was never set */
        if (!s->summed_probabilities[p0][255])
            return AVERROR_INVALIDDATA;

        mult = (high - low) / s->summed_probabilities[p0][255];

        if (!mult) {
            /* the range became too narrow: restart it, taking a fresh
             * 32-bit code value if that many bytes are left */
            if (bytestream2_get_bytes_left(&s->gbyte) >= 4)
                value = bytestream2_get_be32(&s->gbyte);

            low = 0;
            high = 0xffffffff;
            mult = high / s->summed_probabilities[p0][255];

            if (!mult)
                return AVERROR_INVALIDDATA;
        }

        index = (value - low) / mult;

        if (index >= s->summed_probabilities[p0][255])
            return AVERROR_INVALIDDATA;

        if (!dst_r) {
            if ((*dst_l = code = s->value_lookup[p0][index]))
                low += s->summed_probabilities[p0][code-1] * mult;

            dst_l += 4;
        } else {
            if ((code = s->value_lookup[p0][index]))
                low += s->summed_probabilities[p0][code-1] * mult;

            if (chan) {
                *dst_r = code;
                dst_r += 4;
            }
            else {
                *dst_l = code;
                dst_l += 4;
            }

            chan ^= 1;
        }

        high = low + s->probabilities[p0][code] * mult - 1;
        checksum += (checksum << 1) + code;

        /* the next bin is chosen from the low bits of a previous byte: the
         * byte just decoded for mono, the one before it for stereo, i.e.
         * the previous byte of the channel decoded next */
        if (!dst_r) {
            p0 = code & (history_bins-1);
        } else {
            p0 = p1;
            p1 = code & (history_bins-1);
        }

        /* renormalize: while the top bytes of both bounds agree, shift in input */
        while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
            value = (value << 8) | bytestream2_get_byte(&s->gbyte);
            high = (high << 8) | 0xff;
            low <<= 8;
        }
    }

    if (wv_check_crc(s, checksum, 0)) {
        if (s->avctx->err_recognition & AV_EF_CRCCHECK)
            return AVERROR_INVALIDDATA;

        /* replace the corrupt block with a constant 0x69 byte pattern */
        memset(dst_left, 0x69, s->samples * 4);

        if (dst_r)
            memset(dst_right, 0x69, s->samples * 4);
    }

    return 0;
}
745 
/**
 * Unpack a DSD block that stores its bytes verbatim.
 *
 * The remaining input must hold exactly one byte per sample and channel,
 * with the channels interleaved. Output bytes are written with a stride of
 * 4 bytes.
 *
 * @param s         block context; s->gbyte supplies the data, s->samples
 *                  the number of bytes to produce per channel
 * @param dst_left  output of the first channel
 * @param dst_right output of the second channel, NULL for mono
 * @return 0 on success, AVERROR_INVALIDDATA on a size mismatch, or on a
 *         checksum mismatch when AV_EF_CRCCHECK is set; without that flag a
 *         mismatch fills the output with 0x69 and returns 0
 */
static int wv_unpack_dsd_copy(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
{
    const int nb_channels = dst_right ? 2 : 1;
    uint8_t *out_l        = dst_left;
    uint8_t *out_r        = dst_right;
    uint32_t checksum     = 0xFFFFFFFF;

    if (bytestream2_get_bytes_left(&s->gbyte) != s->samples * nb_channels)
        return AVERROR_INVALIDDATA;

    for (int remaining = s->samples; remaining > 0; remaining--) {
        *out_l   = bytestream2_get_byte(&s->gbyte);
        checksum = checksum * 3 + *out_l;
        out_l   += 4;

        if (!out_r)
            continue;

        *out_r   = bytestream2_get_byte(&s->gbyte);
        checksum = checksum * 3 + *out_r;
        out_r   += 4;
    }

    if (!wv_check_crc(s, checksum, 0))
        return 0;

    if (s->avctx->err_recognition & AV_EF_CRCCHECK)
        return AVERROR_INVALIDDATA;

    /* replace the corrupt block with a constant 0x69 byte pattern */
    memset(dst_left, 0x69, s->samples * 4);
    if (dst_right)
        memset(dst_right, 0x69, s->samples * 4);

    return 0;
}
777 
/**
 * Unpack up to s->samples stereo samples: entropy-decode a value pair,
 * run it through all decorrelation terms, undo joint stereo and convert
 * to the output sample format.
 *
 * @param s     block context
 * @param gb    bit reader holding the entropy-coded data
 * @param dst_l output buffer of the first channel
 * @param dst_r output buffer of the second channel
 * @param type  output format: AV_SAMPLE_FMT_FLTP selects float output,
 *              AV_SAMPLE_FMT_S32P 32-bit integers, anything else 16-bit
 *              integers
 * @return 0 on success; AVERROR_INVALIDDATA if a 16-bit sample pair is out
 *         of range, or on a checksum mismatch when AV_EF_CRCCHECK is set.
 *         Running out of input is not an error: the rest of both buffers
 *         is zeroed.
 */
static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
                                   void *dst_l, void *dst_r, const int type)
{
    int i, j, count = 0;
    int last, t;
    int A, B, L, L2, R, R2;
    int pos                 = 0;
    uint32_t crc            = 0xFFFFFFFF;
    uint32_t crc_extra_bits = 0xFFFFFFFF;
    /* typed views of the same output buffers; `type` decides which is used */
    int16_t *dst16_l        = dst_l;
    int16_t *dst16_r        = dst_r;
    int32_t *dst32_l        = dst_l;
    int32_t *dst32_r        = dst_r;
    float *dstfl_l          = dst_l;
    float *dstfl_r          = dst_r;

    s->one = s->zero = s->zeroes = 0;
    do {
        L = wv_get_value(s, gb, 0, &last);
        if (last)
            break;
        R = wv_get_value(s, gb, 1, &last);
        if (last)
            break;
        for (i = 0; i < s->terms; i++) {
            t = s->decorr[i].value;
            if (t > 0) {
                /* positive terms: each channel is predicted from its own
                 * history */
                if (t > 8) {
                    /* prediction extrapolated from the last two samples */
                    if (t & 1) {
                        A = 2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
                        B = 2U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
                    } else {
                        A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
                        B = (int)(3U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
                    }
                    s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
                    s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
                    j                        = 0;
                } else {
                    /* prediction is the sample stored t positions back in
                     * an 8-entry ring buffer */
                    A = s->decorr[i].samplesA[pos];
                    B = s->decorr[i].samplesB[pos];
                    j = (pos + t) & 7;
                }
                /* add the weighted prediction; the 16-bit path does it in
                 * 32-bit unsigned arithmetic, the others use a 64-bit
                 * product */
                if (type != AV_SAMPLE_FMT_S16P) {
                    L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
                    R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
                } else {
                    L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
                    R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)B + 512) >> 10);
                }
                /* move the weight by delta: up when residual and prediction
                 * have the same sign, down otherwise */
                if (A && L)
                    s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
                if (B && R)
                    s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
                s->decorr[i].samplesA[j] = L = L2;
                s->decorr[i].samplesB[j] = R = R2;
            } else if (t == -1) {
                /* term -1: L is predicted from samplesA[0], then R from the
                 * freshly reconstructed L */
                if (type != AV_SAMPLE_FMT_S16P)
                    L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
                else
                    L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)s->decorr[i].samplesA[0] + 512) >> 10);
                UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
                L = L2;
                if (type != AV_SAMPLE_FMT_S16P)
                    R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
                else
                    R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)L2 + 512) >> 10);
                UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
                R                        = R2;
                s->decorr[i].samplesA[0] = R;
            } else {
                /* remaining negative terms: R is predicted from samplesB[0]
                 * first, then L from R (or, for term -3, from the previous
                 * samplesA[0]) */
                if (type != AV_SAMPLE_FMT_S16P)
                    R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
                else
                    R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)s->decorr[i].samplesB[0] + 512) >> 10);
                UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
                R = R2;

                if (t == -3) {
                    R2                       = s->decorr[i].samplesA[0];
                    s->decorr[i].samplesA[0] = R;
                }

                if (type != AV_SAMPLE_FMT_S16P)
                    L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
                else
                    L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)R2 + 512) >> 10);
                UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
                L                        = L2;
                s->decorr[i].samplesB[0] = L;
            }
        }

        /* sanity limit for 16-bit output */
        if (type == AV_SAMPLE_FMT_S16P) {
            if (FFABS((int64_t)L) + FFABS((int64_t)R) > (1<<19)) {
                av_log(s->avctx, AV_LOG_ERROR, "sample %d %d too large\n", L, R);
                return AVERROR_INVALIDDATA;
            }
        }

        pos = (pos + 1) & 7;
        /* undo joint stereo: R -= L >> 1, then L += R (using the new R) */
        if (s->joint)
            L += (unsigned)(R -= (unsigned)(L >> 1));
        crc = (crc * 3 + L) * 3 + R;

        if (type == AV_SAMPLE_FMT_FLTP) {
            *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
            *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
        } else if (type == AV_SAMPLE_FMT_S32P) {
            *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
            *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
        } else {
            *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
            *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
        }
        count++;
    } while (!last && count < s->samples);

    /* decoding stopped early: zero the samples that were not produced */
    if (last && count < s->samples) {
        int size = av_get_bytes_per_sample(type);
        memset((uint8_t*)dst_l + count*size, 0, (s->samples-count)*size);
        memset((uint8_t*)dst_r + count*size, 0, (s->samples-count)*size);
    }

    /* checksums are only verified when the user asked for CRC checking */
    if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
        wv_check_crc(s, crc, crc_extra_bits))
        return AVERROR_INVALIDDATA;

    return 0;
}
908 
/**
 * Unpack the samples of a single-channel block.
 *
 * For every sample a value is read with wv_get_value(), run through the
 * s->terms decorrelation passes stored in s->decorr[], folded into the
 * running CRC and written to dst in the representation selected by type.
 * If the bitstream ends early (wv_get_value() sets its "last" flag) the
 * remainder of the output, up to s->samples, is zero-filled.
 *
 * @param s    per-block decoder state (decorrelation terms, sample count, ...)
 * @param gb   bit reader positioned on the packed samples
 * @param dst  output plane; interpreted as int16_t, int32_t or float
 *             depending on type
 * @param type output sample format (AV_SAMPLE_FMT_S16P, AV_SAMPLE_FMT_S32P
 *             or AV_SAMPLE_FMT_FLTP)
 * @return 0 on success; the negative result of wv_check_crc() only when both
 *         AV_EF_CRCCHECK and AV_EF_EXPLODE are set in err_recognition
 */
static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
                                 void *dst, const int type)
{
    int i, j, count = 0;
    int last, t;
    int A, S, T;
    int pos                  = 0;
    uint32_t crc             = 0xFFFFFFFF;
    uint32_t crc_extra_bits  = 0xFFFFFFFF;
    int16_t *dst16           = dst;
    int32_t *dst32           = dst;
    float *dstfl             = dst;

    s->one = s->zero = s->zeroes = 0;
    do {
        T = wv_get_value(s, gb, 0, &last);
        /* Start from the decoded value itself: with zero decorrelation
         * terms the loop below never runs and the value must pass through
         * unchanged instead of being replaced by 0. For one or more terms
         * S is overwritten by the first pass, so nothing changes there. */
        S = T;
        if (last)
            break;
        for (i = 0; i < s->terms; i++) {
            t = s->decorr[i].value;
            if (t > 8) {
                /* terms above 8 predict from the two most recent samples;
                 * unsigned arithmetic keeps overflow well defined */
                if (t & 1)
                    A =  2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
                else
                    A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
                s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
                j                        = 0;
            } else {
                /* otherwise samplesA[] is used as an 8-entry ring: read at
                 * pos, store the new sample t slots further on */
                A = s->decorr[i].samplesA[pos];
                j = (pos + t) & 7;
            }
            /* add the weighted prediction (weights carry 10 fractional bits);
             * the 16-bit path uses wrapping 32-bit math, the others 64-bit */
            if (type != AV_SAMPLE_FMT_S16P)
                S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
            else
                S = T + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
            /* adapt the weight: +delta when T and A share a sign,
             * -delta when their signs differ */
            if (A && T)
                s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
            s->decorr[i].samplesA[j] = T = S;
        }
        pos = (pos + 1) & 7;
        crc = crc * 3 + S;

        if (type == AV_SAMPLE_FMT_FLTP) {
            *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
        } else if (type == AV_SAMPLE_FMT_S32P) {
            *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
        } else {
            *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
        }
        count++;
    } while (!last && count < s->samples);

    /* bitstream ended early: pad the rest of the plane with zeros */
    if (last && count < s->samples) {
        int size = av_get_bytes_per_sample(type);
        memset((uint8_t*)dst + count*size, 0, (s->samples-count)*size);
    }

    /* a CRC mismatch is only fatal when the caller asked for AV_EF_EXPLODE */
    if (s->avctx->err_recognition & AV_EF_CRCCHECK) {
        int ret = wv_check_crc(s, crc, crc_extra_bits);
        if (ret < 0 && s->avctx->err_recognition & AV_EF_EXPLODE)
            return ret;
    }

    return 0;
}
975 
wv_alloc_frame_context(WavpackContext *c)976 static av_cold int wv_alloc_frame_context(WavpackContext *c)
977 {
978     if (c->fdec_num == WV_MAX_FRAME_DECODERS)
979         return -1;
980 
981     c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
982     if (!c->fdec[c->fdec_num])
983         return -1;
984     c->fdec_num++;
985     c->fdec[c->fdec_num - 1]->avctx = c->avctx;
986 
987     return 0;
988 }
989 
/**
 * Discard the current DSD state and, if requested, allocate a fresh one.
 *
 * @param s        decoder context
 * @param channels number of DSD channel contexts to allocate; 0 only
 *                 releases the existing state
 * @return 0 on success, AVERROR(EINVAL) if the requested size would
 *         overflow, AVERROR(ENOMEM) if the buffer cannot be allocated
 */
static int wv_dsd_reset(WavpackContext *s, int channels)
{
    DSDContext *ctx;
    int remaining;

    /* forget the old per-channel contexts and drop their backing buffer */
    s->dsdctx       = NULL;
    s->dsd_channels = 0;
    av_buffer_unref(&s->dsd_ref);

    if (channels == 0)
        return 0;

    if (channels > INT_MAX / sizeof(*s->dsdctx))
        return AVERROR(EINVAL);

    s->dsd_ref = av_buffer_allocz(channels * sizeof(*s->dsdctx));
    if (s->dsd_ref == NULL)
        return AVERROR(ENOMEM);

    s->dsdctx       = (DSDContext*)s->dsd_ref->data;
    s->dsd_channels = channels;

    /* every channel's history buffer starts out filled with 0x69 */
    for (ctx = s->dsdctx, remaining = channels; remaining > 0; ctx++, remaining--)
        memset(ctx->buf, 0x69, sizeof(ctx->buf));

    return 0;
}
1015 
#if HAVE_THREADS
/**
 * Bring the decoder state of dst in line with src.
 *
 * dst takes a reference to src's current frame (when it holds data) and
 * shares src's DSD state buffer.
 *
 * @return 0 on success, a negative error code if referencing fails
 */
static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
{
    WavpackContext *from = src->priv_data;
    WavpackContext *to   = dst->priv_data;
    int err;

    if (dst == src)
        return 0;

    /* replace dst's current frame by a reference to src's, if there is one */
    ff_thread_release_ext_buffer(dst, &to->curr_frame);
    if (from->curr_frame.f->data[0]) {
        err = ff_thread_ref_frame(&to->curr_frame, &from->curr_frame);
        if (err < 0)
            return err;
    }

    /* the DSD pointers are only re-established once the buffer is shared */
    to->dsdctx       = NULL;
    to->dsd_channels = 0;
    err = av_buffer_replace(&to->dsd_ref, from->dsd_ref);
    if (err < 0)
        return err;

    if (!from->dsd_ref)
        return 0;

    to->dsdctx       = (DSDContext*)to->dsd_ref->data;
    to->dsd_channels = from->dsd_channels;

    return 0;
}
#endif
1045 
wavpack_decode_init(AVCodecContext *avctx)1046 static av_cold int wavpack_decode_init(AVCodecContext *avctx)
1047 {
1048     WavpackContext *s = avctx->priv_data;
1049 
1050     s->avctx = avctx;
1051 
1052     s->fdec_num = 0;
1053 
1054     s->curr_frame.f = av_frame_alloc();
1055     s->prev_frame.f = av_frame_alloc();
1056 
1057     if (!s->curr_frame.f || !s->prev_frame.f)
1058         return AVERROR(ENOMEM);
1059 
1060     ff_init_dsd_data();
1061 
1062     return 0;
1063 }
1064 
wavpack_decode_end(AVCodecContext *avctx)1065 static av_cold int wavpack_decode_end(AVCodecContext *avctx)
1066 {
1067     WavpackContext *s = avctx->priv_data;
1068 
1069     for (int i = 0; i < s->fdec_num; i++)
1070         av_freep(&s->fdec[i]);
1071     s->fdec_num = 0;
1072 
1073     ff_thread_release_ext_buffer(avctx, &s->curr_frame);
1074     av_frame_free(&s->curr_frame.f);
1075 
1076     ff_thread_release_ext_buffer(avctx, &s->prev_frame);
1077     av_frame_free(&s->prev_frame.f);
1078 
1079     av_buffer_unref(&s->dsd_ref);
1080 
1081     return 0;
1082 }
1083 
/**
 * Decode one WavPack block of a packet into the output frame.
 *
 * The remaining block header fields (sample count, flags, CRC) are read,
 * every metadata sub-block is parsed, and the packed PCM or DSD samples are
 * unpacked into the next one or two channel planes of the output frame. The
 * first block of a packet (wc->ch_offset == 0) additionally establishes the
 * stream parameters on avctx and obtains the output buffer.
 *
 * @param avctx    codec context
 * @param block_no index of the block within the packet; selects (and, if
 *                 necessary, allocates) the per-block decoder context
 * @param buf      block data starting at the little-endian 32-bit sample
 *                 count, followed by the flags and the CRC
 * @param buf_size number of bytes available at buf
 * @return 0 on success (also when surplus channels are ignored without
 *         AV_EF_EXPLODE), a negative AVERROR code on failure
 */
static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
                                const uint8_t *buf, int buf_size)
{
    WavpackContext *wc = avctx->priv_data;
    WavpackFrameContext *s;
    GetByteContext gb;
    enum AVSampleFormat sample_fmt;
    void *samples_l = NULL, *samples_r = NULL;
    int ret;
    int got_terms   = 0, got_weights = 0, got_samples = 0,
        got_entropy = 0, got_pcm     = 0, got_float   = 0, got_hybrid = 0;
    int got_dsd = 0;
    int i, j, id, size, ssize, weights, t;
    int bpp, chan = 0, orig_bpp, sample_rate = 0, rate_x = 1, dsd_mode = 0;
    int multiblock;
    uint64_t chmask = 0;

    if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
        av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
        return AVERROR_INVALIDDATA;
    }

    s = wc->fdec[block_no];
    if (!s) {
        av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
               block_no);
        return AVERROR_INVALIDDATA;
    }

    /* start every block from a clean per-block state */
    memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
    memset(s->ch, 0, sizeof(s->ch));
    s->extra_bits     = 0;
    s->and            = s->or = s->shift = 0;
    s->got_extra_bits = 0;

    bytestream2_init(&gb, buf, buf_size);

    s->samples = bytestream2_get_le32(&gb);
    if (s->samples != wc->samples) {
        av_log(avctx, AV_LOG_ERROR, "Mismatching number of samples in "
               "a sequence: %d and %d\n", wc->samples, s->samples);
        return AVERROR_INVALIDDATA;
    }
    s->frame_flags = bytestream2_get_le32(&gb);

    if (s->frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA))
        sample_fmt = AV_SAMPLE_FMT_FLTP;
    else if ((s->frame_flags & 0x03) <= 1)
        sample_fmt = AV_SAMPLE_FMT_S16P;
    else
        sample_fmt = AV_SAMPLE_FMT_S32P;

    /* all blocks of one packet must agree on the output format */
    if (wc->ch_offset && avctx->sample_fmt != sample_fmt)
        return AVERROR_INVALIDDATA;

    bpp            = av_get_bytes_per_sample(sample_fmt);
    orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
    multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;

    s->stereo         = !(s->frame_flags & WV_MONO);
    s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
    s->joint          =   s->frame_flags & WV_JOINT_STEREO;
    s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
    s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
    s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
    if (s->post_shift < 0 || s->post_shift > 31) {
        return AVERROR_INVALIDDATA;
    }
    s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
    s->hybrid_minclip = ((-1UL << (orig_bpp - 1)));
    s->CRC            = bytestream2_get_le32(&gb);

    /* Parse metadata blocks.
     *
     * size  is the payload length in bytes,
     * ssize is the stored length (always even, i.e. including the padding
     *       byte of odd-sized payloads).
     *
     * Every "continue" below bypasses the padding skip at the bottom of the
     * loop, so it has to consume the complete sub-block (ssize bytes in
     * total) itself; otherwise payload bytes would be misread as the header
     * of the next sub-block. */
    while (bytestream2_get_bytes_left(&gb)) {
        id   = bytestream2_get_byte(&gb);
        size = bytestream2_get_byte(&gb);
        if (id & WP_IDF_LONG)
            size |= (bytestream2_get_le16u(&gb)) << 8;
        size <<= 1; // size is specified in words
        ssize  = size;
        if (id & WP_IDF_ODD)
            size--;
        if (size < 0) {
            av_log(avctx, AV_LOG_ERROR,
                   "Got incorrect block %02X with size %i\n", id, size);
            break;
        }
        if (bytestream2_get_bytes_left(&gb) < ssize) {
            av_log(avctx, AV_LOG_ERROR,
                   "Block size %i is out of bounds\n", size);
            break;
        }
        switch (id & WP_IDF_MASK) {
        case WP_ID_DECTERMS:
            if (size > MAX_TERMS) {
                av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
                s->terms = 0;
                bytestream2_skip(&gb, ssize);
                continue;
            }
            s->terms = size;
            /* terms are stored in reverse order of application */
            for (i = 0; i < s->terms; i++) {
                uint8_t val = bytestream2_get_byte(&gb);
                s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
                s->decorr[s->terms - i - 1].delta =  val >> 5;
            }
            got_terms = 1;
            break;
        case WP_ID_DECWEIGHTS:
            if (!got_terms) {
                av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
                bytestream2_skip(&gb, ssize);
                continue;
            }
            weights = size >> s->stereo_in;
            if (weights > MAX_TERMS || weights > s->terms) {
                av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
                bytestream2_skip(&gb, ssize);
                continue;
            }
            for (i = 0; i < weights; i++) {
                t = (int8_t)bytestream2_get_byte(&gb);
                s->decorr[s->terms - i - 1].weightA = t * (1 << 3);
                if (s->decorr[s->terms - i - 1].weightA > 0)
                    s->decorr[s->terms - i - 1].weightA +=
                        (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
                if (s->stereo_in) {
                    t = (int8_t)bytestream2_get_byte(&gb);
                    s->decorr[s->terms - i - 1].weightB = t * (1 << 3);
                    if (s->decorr[s->terms - i - 1].weightB > 0)
                        s->decorr[s->terms - i - 1].weightB +=
                            (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
                }
            }
            got_weights = 1;
            break;
        case WP_ID_DECSAMPLES:
            if (!got_terms) {
                av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
                bytestream2_skip(&gb, ssize);
                continue;
            }
            /* t counts the payload bytes accounted for so far */
            t = 0;
            for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
                if (s->decorr[i].value > 8) {
                    s->decorr[i].samplesA[0] =
                        wp_exp2(bytestream2_get_le16(&gb));
                    s->decorr[i].samplesA[1] =
                        wp_exp2(bytestream2_get_le16(&gb));

                    if (s->stereo_in) {
                        s->decorr[i].samplesB[0] =
                            wp_exp2(bytestream2_get_le16(&gb));
                        s->decorr[i].samplesB[1] =
                            wp_exp2(bytestream2_get_le16(&gb));
                        t                       += 4;
                    }
                    t += 4;
                } else if (s->decorr[i].value < 0) {
                    s->decorr[i].samplesA[0] =
                        wp_exp2(bytestream2_get_le16(&gb));
                    s->decorr[i].samplesB[0] =
                        wp_exp2(bytestream2_get_le16(&gb));
                    t                       += 4;
                } else {
                    for (j = 0; j < s->decorr[i].value; j++) {
                        s->decorr[i].samplesA[j] =
                            wp_exp2(bytestream2_get_le16(&gb));
                        if (s->stereo_in) {
                            s->decorr[i].samplesB[j] =
                                wp_exp2(bytestream2_get_le16(&gb));
                        }
                    }
                    t += s->decorr[i].value * 2 * (s->stereo_in + 1);
                }
            }
            got_samples = 1;
            break;
        case WP_ID_ENTROPY:
            if (size != 6 * (s->stereo_in + 1)) {
                av_log(avctx, AV_LOG_ERROR,
                       "Entropy vars size should be %i, got %i.\n",
                       6 * (s->stereo_in + 1), size);
                bytestream2_skip(&gb, ssize);
                continue;
            }
            for (j = 0; j <= s->stereo_in; j++)
                for (i = 0; i < 3; i++) {
                    s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
                }
            got_entropy = 1;
            break;
        case WP_ID_HYBRID:
            if (s->hybrid_bitrate) {
                for (i = 0; i <= s->stereo_in; i++) {
                    s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
                    size               -= 2;
                }
            }
            for (i = 0; i < (s->stereo_in + 1); i++) {
                s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
                size                -= 2;
            }
            /* bitrate deltas are optional trailing data */
            if (size > 0) {
                for (i = 0; i < (s->stereo_in + 1); i++) {
                    s->ch[i].bitrate_delta =
                        wp_exp2((int16_t)bytestream2_get_le16(&gb));
                }
            } else {
                for (i = 0; i < (s->stereo_in + 1); i++)
                    s->ch[i].bitrate_delta = 0;
            }
            got_hybrid = 1;
            break;
        case WP_ID_INT32INFO: {
            uint8_t val[4];
            if (size != 4) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid INT32INFO, size = %i\n",
                       size);
                /* nothing of the payload has been read yet */
                bytestream2_skip(&gb, ssize);
                continue;
            }
            /* size == 4 implies ssize == 4, so after this read the whole
             * sub-block is consumed and a plain "continue" is safe */
            bytestream2_get_buffer(&gb, val, 4);
            if (val[0] > 30) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid INT32INFO, extra_bits = %d (> 30)\n", val[0]);
                continue;
            } else {
                s->extra_bits = val[0];
            }
            if (val[1])
                s->shift = val[1];
            if (val[2]) {
                s->and   = s->or = 1;
                s->shift = val[2];
            }
            if (val[3]) {
                s->and   = 1;
                s->shift = val[3];
            }
            if (s->shift > 31) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid INT32INFO, shift = %d (> 31)\n", s->shift);
                s->and = s->or = s->shift = 0;
                continue;
            }
            /* original WavPack decoder forces 32-bit lossy sound to be treated
             * as 24-bit one in order to have proper clipping */
            if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
                s->post_shift      += 8;
                s->shift           -= 8;
                s->hybrid_maxclip >>= 8;
                s->hybrid_minclip >>= 8;
            }
            break;
        }
        case WP_ID_FLOATINFO:
            if (size != 4) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid FLOATINFO, size = %i\n", size);
                bytestream2_skip(&gb, ssize);
                continue;
            }
            s->float_flag    = bytestream2_get_byte(&gb);
            s->float_shift   = bytestream2_get_byte(&gb);
            s->float_max_exp = bytestream2_get_byte(&gb);
            if (s->float_shift > 31) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid FLOATINFO, shift = %d (> 31)\n", s->float_shift);
                s->float_shift = 0;
                /* drop the unread fourth payload byte as well */
                bytestream2_skip(&gb, 1);
                continue;
            }
            got_float        = 1;
            bytestream2_skip(&gb, 1);
            break;
        case WP_ID_DATA:
            if ((ret = init_get_bits8(&s->gb, gb.buffer, size)) < 0)
                return ret;
            bytestream2_skip(&gb, size);
            got_pcm      = 1;
            break;
        case WP_ID_DSD_DATA:
            if (size < 2) {
                av_log(avctx, AV_LOG_ERROR, "Invalid DSD_DATA, size = %i\n",
                       size);
                bytestream2_skip(&gb, ssize);
                continue;
            }
            /* first byte: log2 of the sample-rate multiplier */
            rate_x = bytestream2_get_byte(&gb);
            if (rate_x > 30)
                return AVERROR_INVALIDDATA;
            rate_x = 1 << rate_x;
            dsd_mode = bytestream2_get_byte(&gb);
            if (dsd_mode && dsd_mode != 1 && dsd_mode != 3) {
                av_log(avctx, AV_LOG_ERROR, "Invalid DSD encoding mode: %d\n",
                    dsd_mode);
                return AVERROR_INVALIDDATA;
            }
            bytestream2_init(&s->gbyte, gb.buffer, size-2);
            bytestream2_skip(&gb, size-2);
            got_dsd      = 1;
            break;
        case WP_ID_EXTRABITS:
            if (size <= 4) {
                av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
                       size);
                /* ssize, not size: the padding byte must go too */
                bytestream2_skip(&gb, ssize);
                continue;
            }
            if ((ret = init_get_bits8(&s->gb_extra_bits, gb.buffer, size)) < 0)
                return ret;
            s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
            bytestream2_skip(&gb, size);
            s->got_extra_bits  = 1;
            break;
        case WP_ID_CHANINFO:
            if (size <= 1) {
                av_log(avctx, AV_LOG_ERROR,
                       "Insufficient channel information\n");
                return AVERROR_INVALIDDATA;
            }
            chan = bytestream2_get_byte(&gb);
            switch (size - 2) {
            case 0:
                chmask = bytestream2_get_byte(&gb);
                break;
            case 1:
                chmask = bytestream2_get_le16(&gb);
                break;
            case 2:
                chmask = bytestream2_get_le24(&gb);
                break;
            case 3:
                chmask = bytestream2_get_le32(&gb);
                break;
            case 4:
                /* the second byte is read and discarded; size is not
                 * needed any more at this point */
                size = bytestream2_get_byte(&gb);
                chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
                chan  += 1;
                if (avctx->ch_layout.nb_channels != chan)
                    av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
                           " instead of %i.\n", chan, avctx->ch_layout.nb_channels);
                chmask = bytestream2_get_le24(&gb);
                break;
            case 5:
                size = bytestream2_get_byte(&gb);
                chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
                chan  += 1;
                if (avctx->ch_layout.nb_channels != chan)
                    av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
                           " instead of %i.\n", chan, avctx->ch_layout.nb_channels);
                chmask = bytestream2_get_le32(&gb);
                break;
            default:
                av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
                       size);
                /* only the channel-count byte was read; drop the rest of
                 * the payload so that parsing stays in sync */
                bytestream2_skip(&gb, size - 1);
            }
            break;
        case WP_ID_SAMPLE_RATE:
            if (size != 3) {
                av_log(avctx, AV_LOG_ERROR, "Invalid custom sample rate.\n");
                return AVERROR_INVALIDDATA;
            }
            sample_rate = bytestream2_get_le24(&gb);
            break;
        default:
            bytestream2_skip(&gb, size);
        }
        /* odd-sized payloads are followed by one padding byte */
        if (id & WP_IDF_ODD)
            bytestream2_skip(&gb, 1);
    }

    if (got_pcm) {
        if (!got_terms) {
            av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_weights) {
            av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_samples) {
            av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_entropy) {
            av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
            return AVERROR_INVALIDDATA;
        }
        if (s->hybrid && !got_hybrid) {
            av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_float && sample_fmt == AV_SAMPLE_FMT_FLTP) {
            av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
            return AVERROR_INVALIDDATA;
        }
        /* too few extra bits are not fatal: they are simply ignored */
        if (s->got_extra_bits && sample_fmt != AV_SAMPLE_FMT_FLTP) {
            const int size   = get_bits_left(&s->gb_extra_bits);
            const int wanted = s->samples * s->extra_bits << s->stereo_in;
            if (size < wanted) {
                av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
                s->got_extra_bits = 0;
            }
        }
    }

    if (!got_pcm && !got_dsd) {
        av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
        return AVERROR_INVALIDDATA;
    }

    if ((got_pcm && wc->modulation != MODULATION_PCM) ||
        (got_dsd && wc->modulation != MODULATION_DSD)) {
        av_log(avctx, AV_LOG_ERROR, "Invalid PCM/DSD mix encountered\n");
        return AVERROR_INVALIDDATA;
    }

    /* the first block of a packet sets up stream parameters and output */
    if (!wc->ch_offset) {
        AVChannelLayout new_ch_layout = { 0 };
        int new_samplerate;
        int sr = (s->frame_flags >> 23) & 0xf;
        if (sr == 0xf) {
            if (!sample_rate) {
                av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
                return AVERROR_INVALIDDATA;
            }
            new_samplerate = sample_rate;
        } else
            new_samplerate = wv_rates[sr];

        if (new_samplerate * (uint64_t)rate_x > INT_MAX)
            return AVERROR_INVALIDDATA;
        new_samplerate *= rate_x;

        if (multiblock) {
            if (chmask) {
                av_channel_layout_from_mask(&new_ch_layout, chmask);
                if (chan && new_ch_layout.nb_channels != chan) {
                    av_log(avctx, AV_LOG_ERROR, "Channel mask does not match the channel count\n");
                    av_channel_layout_uninit(&new_ch_layout);
                    return AVERROR_INVALIDDATA;
                }
            } else {
                ret = av_channel_layout_copy(&new_ch_layout, &avctx->ch_layout);
                if (ret < 0) {
                    av_log(avctx, AV_LOG_ERROR, "Error copying channel layout\n");
                    av_channel_layout_uninit(&new_ch_layout);
                    return ret;
                }
            }
        } else {
            av_channel_layout_default(&new_ch_layout, s->stereo + 1);
        }

        /* clear DSD state if stream properties change */
        if (new_ch_layout.nb_channels != wc->dsd_channels ||
            av_channel_layout_compare(&new_ch_layout, &avctx->ch_layout) ||
            new_samplerate != avctx->sample_rate    ||
            !!got_dsd      != !!wc->dsdctx) {
            ret = wv_dsd_reset(wc, got_dsd ? new_ch_layout.nb_channels : 0);
            if (ret < 0) {
                av_log(avctx, AV_LOG_ERROR, "Error reinitializing the DSD context\n");
                av_channel_layout_uninit(&new_ch_layout);
                return ret;
            }
            ff_thread_release_ext_buffer(avctx, &wc->curr_frame);
        }
        /* The copy is a deep one, so the temporary layout is released
         * afterwards; a failed copy must not leave a half-initialised
         * layout behind on the codec context. */
        ret = av_channel_layout_copy(&avctx->ch_layout, &new_ch_layout);
        av_channel_layout_uninit(&new_ch_layout);
        if (ret < 0) {
            av_log(avctx, AV_LOG_ERROR, "Error copying channel layout\n");
            av_channel_layout_uninit(&avctx->ch_layout);
            return ret;
        }
        avctx->sample_rate         = new_samplerate;
        avctx->sample_fmt          = sample_fmt;
        avctx->bits_per_raw_sample = orig_bpp;

        /* the frame decoded last becomes the previous frame */
        ff_thread_release_ext_buffer(avctx, &wc->prev_frame);
        FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);

        /* get output buffer */
        wc->curr_frame.f->nb_samples = s->samples;
        ret = ff_thread_get_ext_buffer(avctx, &wc->curr_frame,
                                       AV_GET_BUFFER_FLAG_REF);
        if (ret < 0)
            return ret;

        wc->frame = wc->curr_frame.f;
        ff_thread_finish_setup(avctx);
    }

    if (wc->ch_offset + s->stereo >= avctx->ch_layout.nb_channels) {
        av_log(avctx, AV_LOG_WARNING, "Too many channels coded in a packet.\n");
        return ((avctx->err_recognition & AV_EF_EXPLODE) || !wc->ch_offset) ? AVERROR_INVALIDDATA : 0;
    }

    samples_l = wc->frame->extended_data[wc->ch_offset];
    if (s->stereo)
        samples_r = wc->frame->extended_data[wc->ch_offset + 1];

    wc->ch_offset += 1 + s->stereo;

    if (s->stereo_in) {
        if (got_dsd) {
            if (dsd_mode == 3) {
                ret = wv_unpack_dsd_high(s, samples_l, samples_r);
            } else if (dsd_mode == 1) {
                ret = wv_unpack_dsd_fast(s, samples_l, samples_r);
            } else {
                ret = wv_unpack_dsd_copy(s, samples_l, samples_r);
            }
        } else {
            ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
        }
        if (ret < 0)
            return ret;
    } else {
        if (got_dsd) {
            if (dsd_mode == 3) {
                ret = wv_unpack_dsd_high(s, samples_l, NULL);
            } else if (dsd_mode == 1) {
                ret = wv_unpack_dsd_fast(s, samples_l, NULL);
            } else {
                ret = wv_unpack_dsd_copy(s, samples_l, NULL);
            }
        } else {
            ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
        }
        if (ret < 0)
            return ret;

        /* false stereo: a single coded channel is duplicated */
        if (s->stereo)
            memcpy(samples_r, samples_l, bpp * s->samples);
    }

    return 0;
}
1613 
/**
 * Flush callback: throw away the DSD state held by the decoder.
 */
static void wavpack_decode_flush(AVCodecContext *avctx)
{
    WavpackContext *wc = avctx->priv_data;

    /* requesting zero channels only releases the existing state */
    wv_dsd_reset(wc, 0);
}
1620 
/**
 * Worker converting one channel of a decoded frame from DSD to float PCM.
 *
 * The conversion runs in place: the channel's plane is handed to
 * ff_dsd2pcm_translate() both as the byte source (stride 4) and as the
 * float destination (stride 1).
 *
 * @param avctx    codec context
 * @param frmptr   the AVFrame to process
 * @param jobnr    job index, used as the channel index
 * @param threadnr unused
 * @return always 0
 */
static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
{
    WavpackContext *wc = avctx->priv_data;
    AVFrame *out       = frmptr;
    uint8_t *plane     = (uint8_t *)out->extended_data[jobnr];

    ff_dsd2pcm_translate(&wc->dsdctx[jobnr], wc->samples, 0,
                         plane, 4,
                         (float *)plane, 1);

    return 0;
}
1632 
/**
 * Decode one packet into a single output frame.
 *
 * The packet is walked as a sequence of blocks, each handed to
 * wavpack_decode_block(). The sample count and the PCM/DSD modulation are
 * taken from the header at the very start of the packet. Once every block
 * has been decoded, the number of channels produced (s->ch_offset) must equal
 * avctx->ch_layout.nb_channels. For DSD data each channel is then run through
 * dsd_channel() via avctx->execute2().
 *
 * @param avctx         codec context; its priv_data is the WavpackContext
 * @param rframe        receives a new reference to the decoded frame
 * @param got_frame_ptr set to 1 when a frame has been returned in @p rframe
 * @param avpkt         input packet
 * @return avpkt->size on success, a negative AVERROR code otherwise
 */
static int wavpack_decode_frame(AVCodecContext *avctx, AVFrame *rframe,
                                int *got_frame_ptr, AVPacket *avpkt)
{
    WavpackContext *s  = avctx->priv_data;
    const uint8_t *pos = avpkt->data;
    int bytes_left     = avpkt->size;
    int flags, ret;

    /* a packet no larger than one block header cannot hold any payload */
    if (bytes_left <= WV_HEADER_SIZE)
        return AVERROR_INVALIDDATA;

    s->frame     = NULL;
    s->block     = 0;
    s->ch_offset = 0;

    /* sample count and flags are read from the first header in the packet */
    s->samples = AV_RL32(pos + 20);
    flags      = AV_RL32(pos + 24);
    if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
        av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
               s->samples);
        return AVERROR_INVALIDDATA;
    }

    if (flags & WV_DSD_DATA)
        s->modulation = MODULATION_DSD;
    else
        s->modulation = MODULATION_PCM;

    while (bytes_left > WV_HEADER_SIZE) {
        /* size field at offset 4, less 12, is the number of bytes that
         * follow the first 20 header bytes of this block */
        int block_bytes = AV_RL32(pos + 4) - 12;

        pos        += 20;
        bytes_left -= 20;

        if (block_bytes <= 0 || block_bytes > bytes_left) {
            av_log(avctx, AV_LOG_ERROR,
                   "Block %d has invalid size (size %d vs. %d bytes left)\n",
                   s->block, block_bytes, bytes_left);
            ret = AVERROR_INVALIDDATA;
            goto error;
        }

        ret = wavpack_decode_block(avctx, s->block, pos, block_bytes);
        if (ret < 0)
            goto error;

        s->block++;
        pos        += block_bytes;
        bytes_left -= block_bytes;
    }

    if (s->ch_offset != avctx->ch_layout.nb_channels) {
        av_log(avctx, AV_LOG_ERROR, "Not enough channels coded in a packet.\n");
        ret = AVERROR_INVALIDDATA;
        goto error;
    }

    /* wait for the previous frame to be completed, then drop it */
    ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
    ff_thread_release_ext_buffer(avctx, &s->prev_frame);

    /* DSD planes are converted per channel before the frame is reported done */
    if (s->modulation == MODULATION_DSD)
        avctx->execute2(avctx, dsd_channel, s->frame, NULL, avctx->ch_layout.nb_channels);

    ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);

    ret = av_frame_ref(rframe, s->frame);
    if (ret < 0)
        return ret;

    *got_frame_ptr = 1;

    return avpkt->size;

error:
    /* the progress handshake is only completed if a frame was set up
     * (s->frame non-NULL) before the failure */
    if (s->frame) {
        ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
        ff_thread_release_ext_buffer(avctx, &s->prev_frame);
        ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
    }

    return ret;
}
1707 
1708 const FFCodec ff_wavpack_decoder = {
1709     .p.name         = "wavpack",
1710     .p.long_name    = NULL_IF_CONFIG_SMALL("WavPack"),
1711     .p.type         = AVMEDIA_TYPE_AUDIO,
1712     .p.id           = AV_CODEC_ID_WAVPACK,
1713     .priv_data_size = sizeof(WavpackContext),
1714     .init           = wavpack_decode_init,
1715     .close          = wavpack_decode_end,
1716     FF_CODEC_DECODE_CB(wavpack_decode_frame),
1717     .flush          = wavpack_decode_flush,
1718     .update_thread_context = ONLY_IF_THREADS_ENABLED(update_thread_context),
1719     .p.capabilities = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_FRAME_THREADS |
1720                       AV_CODEC_CAP_SLICE_THREADS | AV_CODEC_CAP_CHANNEL_CONF,
1721     .caps_internal  = FF_CODEC_CAP_INIT_THREADSAFE | FF_CODEC_CAP_INIT_CLEANUP |
1722                       FF_CODEC_CAP_ALLOCATE_PROGRESS,
1723 };
1724