/* xref: /third_party/ffmpeg/libavcodec/wavpack.c (revision cabdff1a) */
/*
 * WavPack lossless audio decoder
 * Copyright (c) 2006,2011 Konstantin Shishkov
 * Copyright (c) 2020 David Bryant
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */
22
23#include "libavutil/buffer.h"
24#include "libavutil/channel_layout.h"
25
26#define BITSTREAM_READER_LE
27#include "avcodec.h"
28#include "bytestream.h"
29#include "codec_internal.h"
30#include "get_bits.h"
31#include "thread.h"
32#include "threadframe.h"
33#include "unary.h"
34#include "wavpack.h"
35#include "dsd.h"
36
37/**
38 * @file
39 * WavPack lossless audio decoder
40 */
41
/*
 * Range-coder helper of the DSD decoders: true when the top byte of both
 * interval bounds is identical, which is when the decoders shift in the
 * next input byte. The test is symmetric in its two arguments.
 */
#define DSD_BYTE_READY(low,high) (!(((low) ^ (high)) & 0xff000000))

/* Size and index mask of the adaptive table WavpackFrameContext.ptable. */
#define PTABLE_BITS 8
#define PTABLE_BINS (1<<PTABLE_BITS)
#define PTABLE_MASK (PTABLE_BINS-1)

/*
 * After each decoded bit the selected table entry moves by 1/2^DECAY of its
 * distance towards UP or DOWN.
 */
#define UP    0x010000fe
#define DOWN  0x00010000
#define DECAY 8

/*
 * Fixed-point layout of the DSD prediction filters: header bytes are scaled
 * by 2^(PRECISION - 8) and the predictor is reduced by
 * (PRECISION - PRECISION_USE) bits before it indexes the table.
 */
#define PRECISION     20
#define VALUE_ONE     (1 << PRECISION)
#define PRECISION_USE 12

/* The only rate_s header value accepted by wv_unpack_dsd_high(). */
#define RATE_S 20

/* Table limits of wv_unpack_dsd_fast(). */
#define MAX_HISTORY_BITS    5
#define MAX_HISTORY_BINS    (1 << MAX_HISTORY_BITS)
#define MAX_BIN_BYTES       1280    // for value_lookup, per bin (2k - 512 - 256)
61
/** Sample modulation; stored in WavpackContext.modulation. */
typedef enum {
    MODULATION_PCM = 0, ///< pulse code modulation
    MODULATION_DSD = 1  ///< pulse density modulation (aka DSD)
} Modulation;
66
67typedef struct WavpackFrameContext {
68    AVCodecContext *avctx;
69    int frame_flags;
70    int stereo, stereo_in;
71    int joint;
72    uint32_t CRC;
73    GetBitContext gb;
74    int got_extra_bits;
75    uint32_t crc_extra_bits;
76    GetBitContext gb_extra_bits;
77    int samples;
78    int terms;
79    Decorr decorr[MAX_TERMS];
80    int zero, one, zeroes;
81    int extra_bits;
82    int and, or, shift;
83    int post_shift;
84    int hybrid, hybrid_bitrate;
85    int hybrid_maxclip, hybrid_minclip;
86    int float_flag;
87    int float_shift;
88    int float_max_exp;
89    WvChannel ch[2];
90
91    GetByteContext gbyte;
92    int ptable [PTABLE_BINS];
93    uint8_t value_lookup_buffer[MAX_HISTORY_BINS*MAX_BIN_BYTES];
94    uint16_t summed_probabilities[MAX_HISTORY_BINS][256];
95    uint8_t probabilities[MAX_HISTORY_BINS][256];
96    uint8_t *value_lookup[MAX_HISTORY_BINS];
97} WavpackFrameContext;
98
99#define WV_MAX_FRAME_DECODERS 14
100
101typedef struct WavpackContext {
102    AVCodecContext *avctx;
103
104    WavpackFrameContext *fdec[WV_MAX_FRAME_DECODERS];
105    int fdec_num;
106
107    int block;
108    int samples;
109    int ch_offset;
110
111    AVFrame *frame;
112    ThreadFrame curr_frame, prev_frame;
113    Modulation modulation;
114
115    AVBufferRef *dsd_ref;
116    DSDContext *dsdctx;
117    int dsd_channels;
118} WavpackContext;
119
/* a / 256 rounded to nearest; the amount by which slow_level decays per sample. */
#define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
121
122static av_always_inline unsigned get_tail(GetBitContext *gb, int k)
123{
124    int p, e, res;
125
126    if (k < 1)
127        return 0;
128    p   = av_log2(k);
129    e   = (1 << (p + 1)) - k - 1;
130    res = get_bitsz(gb, p);
131    if (res >= e)
132        res = res * 2U - e + get_bits1(gb);
133    return res;
134}
135
136static int update_error_limit(WavpackFrameContext *ctx)
137{
138    int i, br[2], sl[2];
139
140    for (i = 0; i <= ctx->stereo_in; i++) {
141        if (ctx->ch[i].bitrate_acc > UINT_MAX - ctx->ch[i].bitrate_delta)
142            return AVERROR_INVALIDDATA;
143        ctx->ch[i].bitrate_acc += ctx->ch[i].bitrate_delta;
144        br[i]                   = ctx->ch[i].bitrate_acc >> 16;
145        sl[i]                   = LEVEL_DECAY(ctx->ch[i].slow_level);
146    }
147    if (ctx->stereo_in && ctx->hybrid_bitrate) {
148        int balance = (sl[1] - sl[0] + br[1] + 1) >> 1;
149        if (balance > br[0]) {
150            br[1] = br[0] * 2;
151            br[0] = 0;
152        } else if (-balance > br[0]) {
153            br[0]  *= 2;
154            br[1]   = 0;
155        } else {
156            br[1] = br[0] + balance;
157            br[0] = br[0] - balance;
158        }
159    }
160    for (i = 0; i <= ctx->stereo_in; i++) {
161        if (ctx->hybrid_bitrate) {
162            if (sl[i] - br[i] > -0x100)
163                ctx->ch[i].error_limit = wp_exp2(sl[i] - br[i] + 0x100);
164            else
165                ctx->ch[i].error_limit = 0;
166        } else {
167            ctx->ch[i].error_limit = wp_exp2(br[i]);
168        }
169    }
170
171    return 0;
172}
173
174static int wv_get_value(WavpackFrameContext *ctx, GetBitContext *gb,
175                        int channel, int *last)
176{
177    int t, t2;
178    int sign, base, add, ret;
179    WvChannel *c = &ctx->ch[channel];
180
181    *last = 0;
182
183    if ((ctx->ch[0].median[0] < 2U) && (ctx->ch[1].median[0] < 2U) &&
184        !ctx->zero && !ctx->one) {
185        if (ctx->zeroes) {
186            ctx->zeroes--;
187            if (ctx->zeroes) {
188                c->slow_level -= LEVEL_DECAY(c->slow_level);
189                return 0;
190            }
191        } else {
192            t = get_unary_0_33(gb);
193            if (t >= 2) {
194                if (t >= 32 || get_bits_left(gb) < t - 1)
195                    goto error;
196                t = get_bits_long(gb, t - 1) | (1 << (t - 1));
197            } else {
198                if (get_bits_left(gb) < 0)
199                    goto error;
200            }
201            ctx->zeroes = t;
202            if (ctx->zeroes) {
203                memset(ctx->ch[0].median, 0, sizeof(ctx->ch[0].median));
204                memset(ctx->ch[1].median, 0, sizeof(ctx->ch[1].median));
205                c->slow_level -= LEVEL_DECAY(c->slow_level);
206                return 0;
207            }
208        }
209    }
210
211    if (ctx->zero) {
212        t         = 0;
213        ctx->zero = 0;
214    } else {
215        t = get_unary_0_33(gb);
216        if (get_bits_left(gb) < 0)
217            goto error;
218        if (t == 16) {
219            t2 = get_unary_0_33(gb);
220            if (t2 < 2) {
221                if (get_bits_left(gb) < 0)
222                    goto error;
223                t += t2;
224            } else {
225                if (t2 >= 32 || get_bits_left(gb) < t2 - 1)
226                    goto error;
227                t += get_bits_long(gb, t2 - 1) | (1 << (t2 - 1));
228            }
229        }
230
231        if (ctx->one) {
232            ctx->one = t & 1;
233            t        = (t >> 1) + 1;
234        } else {
235            ctx->one = t & 1;
236            t      >>= 1;
237        }
238        ctx->zero = !ctx->one;
239    }
240
241    if (ctx->hybrid && !channel) {
242        if (update_error_limit(ctx) < 0)
243            goto error;
244    }
245
246    if (!t) {
247        base = 0;
248        add  = GET_MED(0) - 1;
249        DEC_MED(0);
250    } else if (t == 1) {
251        base = GET_MED(0);
252        add  = GET_MED(1) - 1;
253        INC_MED(0);
254        DEC_MED(1);
255    } else if (t == 2) {
256        base = GET_MED(0) + GET_MED(1);
257        add  = GET_MED(2) - 1;
258        INC_MED(0);
259        INC_MED(1);
260        DEC_MED(2);
261    } else {
262        base = GET_MED(0) + GET_MED(1) + GET_MED(2) * (t - 2U);
263        add  = GET_MED(2) - 1;
264        INC_MED(0);
265        INC_MED(1);
266        INC_MED(2);
267    }
268    if (!c->error_limit) {
269        if (add >= 0x2000000U) {
270            av_log(ctx->avctx, AV_LOG_ERROR, "k %d is too large\n", add);
271            goto error;
272        }
273        ret = base + get_tail(gb, add);
274        if (get_bits_left(gb) <= 0)
275            goto error;
276    } else {
277        int mid = (base * 2U + add + 1) >> 1;
278        while (add > c->error_limit) {
279            if (get_bits_left(gb) <= 0)
280                goto error;
281            if (get_bits1(gb)) {
282                add -= (mid - (unsigned)base);
283                base = mid;
284            } else
285                add = mid - (unsigned)base - 1;
286            mid = (base * 2U + add + 1) >> 1;
287        }
288        ret = mid;
289    }
290    sign = get_bits1(gb);
291    if (ctx->hybrid_bitrate)
292        c->slow_level += wp_log2(ret) - LEVEL_DECAY(c->slow_level);
293    return sign ? ~ret : ret;
294
295error:
296    ret = get_bits_left(gb);
297    if (ret <= 0) {
298        av_log(ctx->avctx, AV_LOG_ERROR, "Too few bits (%d) left\n", ret);
299    }
300    *last = 1;
301    return 0;
302}
303
304static inline int wv_get_value_integer(WavpackFrameContext *s, uint32_t *crc,
305                                       unsigned S)
306{
307    unsigned bit;
308
309    if (s->extra_bits) {
310        S *= 1 << s->extra_bits;
311
312        if (s->got_extra_bits &&
313            get_bits_left(&s->gb_extra_bits) >= s->extra_bits) {
314            S   |= get_bits_long(&s->gb_extra_bits, s->extra_bits);
315            *crc = *crc * 9 + (S & 0xffff) * 3 + ((unsigned)S >> 16);
316        }
317    }
318
319    bit = (S & s->and) | s->or;
320    bit = ((S + bit) << s->shift) - bit;
321
322    if (s->hybrid)
323        bit = av_clip(bit, s->hybrid_minclip, s->hybrid_maxclip);
324
325    return bit << s->post_shift;
326}
327
328static float wv_get_value_float(WavpackFrameContext *s, uint32_t *crc, int S)
329{
330    union {
331        float    f;
332        uint32_t u;
333    } value;
334
335    unsigned int sign;
336    int exp = s->float_max_exp;
337
338    if (s->got_extra_bits) {
339        const int max_bits  = 1 + 23 + 8 + 1;
340        const int left_bits = get_bits_left(&s->gb_extra_bits);
341
342        if (left_bits + 8 * AV_INPUT_BUFFER_PADDING_SIZE < max_bits)
343            return 0.0;
344    }
345
346    if (S) {
347        S  *= 1U << s->float_shift;
348        sign = S < 0;
349        if (sign)
350            S = -(unsigned)S;
351        if (S >= 0x1000000U) {
352            if (s->got_extra_bits && get_bits1(&s->gb_extra_bits))
353                S = get_bits(&s->gb_extra_bits, 23);
354            else
355                S = 0;
356            exp = 255;
357        } else if (exp) {
358            int shift = 23 - av_log2(S);
359            exp = s->float_max_exp;
360            if (exp <= shift)
361                shift = --exp;
362            exp -= shift;
363
364            if (shift) {
365                S <<= shift;
366                if ((s->float_flag & WV_FLT_SHIFT_ONES) ||
367                    (s->got_extra_bits &&
368                     (s->float_flag & WV_FLT_SHIFT_SAME) &&
369                     get_bits1(&s->gb_extra_bits))) {
370                    S |= (1 << shift) - 1;
371                } else if (s->got_extra_bits &&
372                           (s->float_flag & WV_FLT_SHIFT_SENT)) {
373                    S |= get_bits(&s->gb_extra_bits, shift);
374                }
375            }
376        } else {
377            exp = s->float_max_exp;
378        }
379        S &= 0x7fffff;
380    } else {
381        sign = 0;
382        exp  = 0;
383        if (s->got_extra_bits && (s->float_flag & WV_FLT_ZERO_SENT)) {
384            if (get_bits1(&s->gb_extra_bits)) {
385                S = get_bits(&s->gb_extra_bits, 23);
386                if (s->float_max_exp >= 25)
387                    exp = get_bits(&s->gb_extra_bits, 8);
388                sign = get_bits1(&s->gb_extra_bits);
389            } else {
390                if (s->float_flag & WV_FLT_ZERO_SIGN)
391                    sign = get_bits1(&s->gb_extra_bits);
392            }
393        }
394    }
395
396    *crc = *crc * 27 + S * 9 + exp * 3 + sign;
397
398    value.u = (sign << 31) | (exp << 23) | S;
399    return value.f;
400}
401
402static inline int wv_check_crc(WavpackFrameContext *s, uint32_t crc,
403                               uint32_t crc_extra_bits)
404{
405    if (crc != s->CRC) {
406        av_log(s->avctx, AV_LOG_ERROR, "CRC error\n");
407        return AVERROR_INVALIDDATA;
408    }
409    if (s->got_extra_bits && crc_extra_bits != s->crc_extra_bits) {
410        av_log(s->avctx, AV_LOG_ERROR, "Extra bits CRC error\n");
411        return AVERROR_INVALIDDATA;
412    }
413
414    return 0;
415}
416
417static void init_ptable(int *table, int rate_i, int rate_s)
418{
419    int value = 0x808000, rate = rate_i << 8;
420
421    for (int c = (rate + 128) >> 8; c--;)
422        value += (DOWN - value) >> DECAY;
423
424    for (int i = 0; i < PTABLE_BINS/2; i++) {
425        table[i] = value;
426        table[PTABLE_BINS-1-i] = 0x100ffff - value;
427
428        if (value > 0x010000) {
429            rate += (rate * rate_s + 128) >> 8;
430
431            for (int c = (rate + 64) >> 7; c--;)
432                value += (DOWN - value) >> DECAY;
433        }
434    }
435}
436
/** Per-channel prediction filter state of wv_unpack_dsd_high(). */
typedef struct {
    int32_t value;      ///< predictor output; its upper bits index the probability table
    int32_t fltr0;      ///< last decoded bit as a mask: -1 or 0
    int32_t fltr1;      ///< filter stages; fltr1..fltr5 are seeded from the block header
    int32_t fltr2;
    int32_t fltr3;
    int32_t fltr4;
    int32_t fltr5;
    int32_t fltr6;      ///< starts at 0
    int32_t factor;     ///< adaptive weight of fltr6, seeded with a signed 16-bit value
    unsigned int byte;  ///< shift register collecting the decoded bits
} DSDfilters;
441
442static int wv_unpack_dsd_high(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
443{
444    uint32_t checksum = 0xFFFFFFFF;
445    uint8_t *dst_l = dst_left, *dst_r = dst_right;
446    int total_samples = s->samples, stereo = dst_r ? 1 : 0;
447    DSDfilters filters[2], *sp = filters;
448    int rate_i, rate_s;
449    uint32_t low, high, value;
450
451    if (bytestream2_get_bytes_left(&s->gbyte) < (stereo ? 20 : 13))
452        return AVERROR_INVALIDDATA;
453
454    rate_i = bytestream2_get_byte(&s->gbyte);
455    rate_s = bytestream2_get_byte(&s->gbyte);
456
457    if (rate_s != RATE_S)
458        return AVERROR_INVALIDDATA;
459
460    init_ptable(s->ptable, rate_i, rate_s);
461
462    for (int channel = 0; channel < stereo + 1; channel++) {
463        DSDfilters *sp = filters + channel;
464
465        sp->fltr1 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
466        sp->fltr2 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
467        sp->fltr3 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
468        sp->fltr4 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
469        sp->fltr5 = bytestream2_get_byte(&s->gbyte) << (PRECISION - 8);
470        sp->fltr6 = 0;
471        sp->factor = bytestream2_get_byte(&s->gbyte) & 0xff;
472        sp->factor |= (bytestream2_get_byte(&s->gbyte) << 8) & 0xff00;
473        sp->factor = (int32_t)((uint32_t)sp->factor << 16) >> 16;
474    }
475
476    value = bytestream2_get_be32(&s->gbyte);
477    high = 0xffffffff;
478    low = 0x0;
479
480    while (total_samples--) {
481        int bitcount = 8;
482
483        sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
484
485        if (stereo)
486            sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
487
488        while (bitcount--) {
489            int32_t *pp = s->ptable + ((sp[0].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
490            uint32_t split = low + ((high - low) >> 8) * (*pp >> 16);
491
492            if (value <= split) {
493                high = split;
494                *pp += (UP - *pp) >> DECAY;
495                sp[0].fltr0 = -1;
496            } else {
497                low = split + 1;
498                *pp += (DOWN - *pp) >> DECAY;
499                sp[0].fltr0 = 0;
500            }
501
502            if (DSD_BYTE_READY(high, low) && !bytestream2_get_bytes_left(&s->gbyte))
503                return AVERROR_INVALIDDATA;
504            while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
505                value = (value << 8) | bytestream2_get_byte(&s->gbyte);
506                high = (high << 8) | 0xff;
507                low <<= 8;
508            }
509
510            sp[0].value += sp[0].fltr6 * 8;
511            sp[0].byte = (sp[0].byte << 1) | (sp[0].fltr0 & 1);
512            sp[0].factor += (((sp[0].value ^ sp[0].fltr0) >> 31) | 1) &
513                ((sp[0].value ^ (sp[0].value - (sp[0].fltr6 * 16))) >> 31);
514            sp[0].fltr1 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr1) >> 6;
515            sp[0].fltr2 += ((sp[0].fltr0 & VALUE_ONE) - sp[0].fltr2) >> 4;
516            sp[0].fltr3 += (sp[0].fltr2 - sp[0].fltr3) >> 4;
517            sp[0].fltr4 += (sp[0].fltr3 - sp[0].fltr4) >> 4;
518            sp[0].value = (sp[0].fltr4 - sp[0].fltr5) >> 4;
519            sp[0].fltr5 += sp[0].value;
520            sp[0].fltr6 += (sp[0].value - sp[0].fltr6) >> 3;
521            sp[0].value = sp[0].fltr1 - sp[0].fltr5 + ((sp[0].fltr6 * sp[0].factor) >> 2);
522
523            if (!stereo)
524                continue;
525
526            pp = s->ptable + ((sp[1].value >> (PRECISION - PRECISION_USE)) & PTABLE_MASK);
527            split = low + ((high - low) >> 8) * (*pp >> 16);
528
529            if (value <= split) {
530                high = split;
531                *pp += (UP - *pp) >> DECAY;
532                sp[1].fltr0 = -1;
533            } else {
534                low = split + 1;
535                *pp += (DOWN - *pp) >> DECAY;
536                sp[1].fltr0 = 0;
537            }
538
539            if (DSD_BYTE_READY(high, low) && !bytestream2_get_bytes_left(&s->gbyte))
540                return AVERROR_INVALIDDATA;
541            while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
542                value = (value << 8) | bytestream2_get_byte(&s->gbyte);
543                high = (high << 8) | 0xff;
544                low <<= 8;
545            }
546
547            sp[1].value += sp[1].fltr6 * 8;
548            sp[1].byte = (sp[1].byte << 1) | (sp[1].fltr0 & 1);
549            sp[1].factor += (((sp[1].value ^ sp[1].fltr0) >> 31) | 1) &
550                ((sp[1].value ^ (sp[1].value - (sp[1].fltr6 * 16))) >> 31);
551            sp[1].fltr1 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr1) >> 6;
552            sp[1].fltr2 += ((sp[1].fltr0 & VALUE_ONE) - sp[1].fltr2) >> 4;
553            sp[1].fltr3 += (sp[1].fltr2 - sp[1].fltr3) >> 4;
554            sp[1].fltr4 += (sp[1].fltr3 - sp[1].fltr4) >> 4;
555            sp[1].value = (sp[1].fltr4 - sp[1].fltr5) >> 4;
556            sp[1].fltr5 += sp[1].value;
557            sp[1].fltr6 += (sp[1].value - sp[1].fltr6) >> 3;
558            sp[1].value = sp[1].fltr1 - sp[1].fltr5 + ((sp[1].fltr6 * sp[1].factor) >> 2);
559        }
560
561        checksum += (checksum << 1) + (*dst_l = sp[0].byte & 0xff);
562        sp[0].factor -= (sp[0].factor + 512) >> 10;
563        dst_l += 4;
564
565        if (stereo) {
566            checksum += (checksum << 1) + (*dst_r = filters[1].byte & 0xff);
567            filters[1].factor -= (filters[1].factor + 512) >> 10;
568            dst_r += 4;
569        }
570    }
571
572    if (wv_check_crc(s, checksum, 0)) {
573        if (s->avctx->err_recognition & AV_EF_CRCCHECK)
574            return AVERROR_INVALIDDATA;
575
576        memset(dst_left, 0x69, s->samples * 4);
577
578        if (dst_r)
579            memset(dst_right, 0x69, s->samples * 4);
580    }
581
582    return 0;
583}
584
585static int wv_unpack_dsd_fast(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
586{
587    uint8_t *dst_l = dst_left, *dst_r = dst_right;
588    uint8_t history_bits, max_probability;
589    int total_summed_probabilities  = 0;
590    int total_samples               = s->samples;
591    uint8_t *vlb                    = s->value_lookup_buffer;
592    int history_bins, p0, p1, chan;
593    uint32_t checksum               = 0xFFFFFFFF;
594    uint32_t low, high, value;
595
596    if (!bytestream2_get_bytes_left(&s->gbyte))
597        return AVERROR_INVALIDDATA;
598
599    history_bits = bytestream2_get_byte(&s->gbyte);
600
601    if (!bytestream2_get_bytes_left(&s->gbyte) || history_bits > MAX_HISTORY_BITS)
602        return AVERROR_INVALIDDATA;
603
604    history_bins = 1 << history_bits;
605    max_probability = bytestream2_get_byte(&s->gbyte);
606
607    if (max_probability < 0xff) {
608        uint8_t *outptr = (uint8_t *)s->probabilities;
609        uint8_t *outend = outptr + sizeof(*s->probabilities) * history_bins;
610
611        while (outptr < outend && bytestream2_get_bytes_left(&s->gbyte)) {
612            int code = bytestream2_get_byte(&s->gbyte);
613
614            if (code > max_probability) {
615                int zcount = code - max_probability;
616
617                while (outptr < outend && zcount--)
618                    *outptr++ = 0;
619            } else if (code) {
620                *outptr++ = code;
621            }
622            else {
623                break;
624            }
625        }
626
627        if (outptr < outend ||
628            (bytestream2_get_bytes_left(&s->gbyte) && bytestream2_get_byte(&s->gbyte)))
629                return AVERROR_INVALIDDATA;
630    } else if (bytestream2_get_bytes_left(&s->gbyte) > (int)sizeof(*s->probabilities) * history_bins) {
631        bytestream2_get_buffer(&s->gbyte, (uint8_t *)s->probabilities,
632            sizeof(*s->probabilities) * history_bins);
633    } else {
634        return AVERROR_INVALIDDATA;
635    }
636
637    for (p0 = 0; p0 < history_bins; p0++) {
638        int32_t sum_values = 0;
639
640        for (int i = 0; i < 256; i++)
641            s->summed_probabilities[p0][i] = sum_values += s->probabilities[p0][i];
642
643        if (sum_values) {
644            total_summed_probabilities += sum_values;
645
646            if (total_summed_probabilities > history_bins * MAX_BIN_BYTES)
647                return AVERROR_INVALIDDATA;
648
649            s->value_lookup[p0] = vlb;
650
651            for (int i = 0; i < 256; i++) {
652                int c = s->probabilities[p0][i];
653
654                while (c--)
655                    *vlb++ = i;
656            }
657        }
658    }
659
660    if (bytestream2_get_bytes_left(&s->gbyte) < 4)
661        return AVERROR_INVALIDDATA;
662
663    chan = p0 = p1 = 0;
664    low = 0; high = 0xffffffff;
665    value = bytestream2_get_be32(&s->gbyte);
666
667    if (dst_r)
668        total_samples *= 2;
669
670    while (total_samples--) {
671        unsigned int mult, index, code;
672
673        if (!s->summed_probabilities[p0][255])
674            return AVERROR_INVALIDDATA;
675
676        mult = (high - low) / s->summed_probabilities[p0][255];
677
678        if (!mult) {
679            if (bytestream2_get_bytes_left(&s->gbyte) >= 4)
680                value = bytestream2_get_be32(&s->gbyte);
681
682            low = 0;
683            high = 0xffffffff;
684            mult = high / s->summed_probabilities[p0][255];
685
686            if (!mult)
687                return AVERROR_INVALIDDATA;
688        }
689
690        index = (value - low) / mult;
691
692        if (index >= s->summed_probabilities[p0][255])
693            return AVERROR_INVALIDDATA;
694
695        if (!dst_r) {
696            if ((*dst_l = code = s->value_lookup[p0][index]))
697                low += s->summed_probabilities[p0][code-1] * mult;
698
699            dst_l += 4;
700        } else {
701            if ((code = s->value_lookup[p0][index]))
702                low += s->summed_probabilities[p0][code-1] * mult;
703
704            if (chan) {
705                *dst_r = code;
706                dst_r += 4;
707            }
708            else {
709                *dst_l = code;
710                dst_l += 4;
711            }
712
713            chan ^= 1;
714        }
715
716        high = low + s->probabilities[p0][code] * mult - 1;
717        checksum += (checksum << 1) + code;
718
719        if (!dst_r) {
720            p0 = code & (history_bins-1);
721        } else {
722            p0 = p1;
723            p1 = code & (history_bins-1);
724        }
725
726        while (DSD_BYTE_READY(high, low) && bytestream2_get_bytes_left(&s->gbyte)) {
727            value = (value << 8) | bytestream2_get_byte(&s->gbyte);
728            high = (high << 8) | 0xff;
729            low <<= 8;
730        }
731    }
732
733    if (wv_check_crc(s, checksum, 0)) {
734        if (s->avctx->err_recognition & AV_EF_CRCCHECK)
735            return AVERROR_INVALIDDATA;
736
737        memset(dst_left, 0x69, s->samples * 4);
738
739        if (dst_r)
740            memset(dst_right, 0x69, s->samples * 4);
741    }
742
743    return 0;
744}
745
746static int wv_unpack_dsd_copy(WavpackFrameContext *s, uint8_t *dst_left, uint8_t *dst_right)
747{
748    uint8_t *dst_l = dst_left, *dst_r = dst_right;
749    int total_samples           = s->samples;
750    uint32_t checksum           = 0xFFFFFFFF;
751
752    if (bytestream2_get_bytes_left(&s->gbyte) != total_samples * (dst_r ? 2 : 1))
753        return AVERROR_INVALIDDATA;
754
755    while (total_samples--) {
756        checksum += (checksum << 1) + (*dst_l = bytestream2_get_byte(&s->gbyte));
757        dst_l += 4;
758
759        if (dst_r) {
760            checksum += (checksum << 1) + (*dst_r = bytestream2_get_byte(&s->gbyte));
761            dst_r += 4;
762        }
763    }
764
765    if (wv_check_crc(s, checksum, 0)) {
766        if (s->avctx->err_recognition & AV_EF_CRCCHECK)
767            return AVERROR_INVALIDDATA;
768
769        memset(dst_left, 0x69, s->samples * 4);
770
771        if (dst_r)
772            memset(dst_right, 0x69, s->samples * 4);
773    }
774
775    return 0;
776}
777
778static inline int wv_unpack_stereo(WavpackFrameContext *s, GetBitContext *gb,
779                                   void *dst_l, void *dst_r, const int type)
780{
781    int i, j, count = 0;
782    int last, t;
783    int A, B, L, L2, R, R2;
784    int pos                 = 0;
785    uint32_t crc            = 0xFFFFFFFF;
786    uint32_t crc_extra_bits = 0xFFFFFFFF;
787    int16_t *dst16_l        = dst_l;
788    int16_t *dst16_r        = dst_r;
789    int32_t *dst32_l        = dst_l;
790    int32_t *dst32_r        = dst_r;
791    float *dstfl_l          = dst_l;
792    float *dstfl_r          = dst_r;
793
794    s->one = s->zero = s->zeroes = 0;
795    do {
796        L = wv_get_value(s, gb, 0, &last);
797        if (last)
798            break;
799        R = wv_get_value(s, gb, 1, &last);
800        if (last)
801            break;
802        for (i = 0; i < s->terms; i++) {
803            t = s->decorr[i].value;
804            if (t > 0) {
805                if (t > 8) {
806                    if (t & 1) {
807                        A = 2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
808                        B = 2U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1];
809                    } else {
810                        A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
811                        B = (int)(3U * s->decorr[i].samplesB[0] - s->decorr[i].samplesB[1]) >> 1;
812                    }
813                    s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
814                    s->decorr[i].samplesB[1] = s->decorr[i].samplesB[0];
815                    j                        = 0;
816                } else {
817                    A = s->decorr[i].samplesA[pos];
818                    B = s->decorr[i].samplesB[pos];
819                    j = (pos + t) & 7;
820                }
821                if (type != AV_SAMPLE_FMT_S16P) {
822                    L2 = L + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
823                    R2 = R + ((s->decorr[i].weightB * (int64_t)B + 512) >> 10);
824                } else {
825                    L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
826                    R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)B + 512) >> 10);
827                }
828                if (A && L)
829                    s->decorr[i].weightA -= ((((L ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
830                if (B && R)
831                    s->decorr[i].weightB -= ((((R ^ B) >> 30) & 2) - 1) * s->decorr[i].delta;
832                s->decorr[i].samplesA[j] = L = L2;
833                s->decorr[i].samplesB[j] = R = R2;
834            } else if (t == -1) {
835                if (type != AV_SAMPLE_FMT_S16P)
836                    L2 = L + ((s->decorr[i].weightA * (int64_t)s->decorr[i].samplesA[0] + 512) >> 10);
837                else
838                    L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)s->decorr[i].samplesA[0] + 512) >> 10);
839                UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, s->decorr[i].samplesA[0], L);
840                L = L2;
841                if (type != AV_SAMPLE_FMT_S16P)
842                    R2 = R + ((s->decorr[i].weightB * (int64_t)L2 + 512) >> 10);
843                else
844                    R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)L2 + 512) >> 10);
845                UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, L2, R);
846                R                        = R2;
847                s->decorr[i].samplesA[0] = R;
848            } else {
849                if (type != AV_SAMPLE_FMT_S16P)
850                    R2 = R + ((s->decorr[i].weightB * (int64_t)s->decorr[i].samplesB[0] + 512) >> 10);
851                else
852                    R2 = R + (unsigned)((int)(s->decorr[i].weightB * (unsigned)s->decorr[i].samplesB[0] + 512) >> 10);
853                UPDATE_WEIGHT_CLIP(s->decorr[i].weightB, s->decorr[i].delta, s->decorr[i].samplesB[0], R);
854                R = R2;
855
856                if (t == -3) {
857                    R2                       = s->decorr[i].samplesA[0];
858                    s->decorr[i].samplesA[0] = R;
859                }
860
861                if (type != AV_SAMPLE_FMT_S16P)
862                    L2 = L + ((s->decorr[i].weightA * (int64_t)R2 + 512) >> 10);
863                else
864                    L2 = L + (unsigned)((int)(s->decorr[i].weightA * (unsigned)R2 + 512) >> 10);
865                UPDATE_WEIGHT_CLIP(s->decorr[i].weightA, s->decorr[i].delta, R2, L);
866                L                        = L2;
867                s->decorr[i].samplesB[0] = L;
868            }
869        }
870
871        if (type == AV_SAMPLE_FMT_S16P) {
872            if (FFABS((int64_t)L) + FFABS((int64_t)R) > (1<<19)) {
873                av_log(s->avctx, AV_LOG_ERROR, "sample %d %d too large\n", L, R);
874                return AVERROR_INVALIDDATA;
875            }
876        }
877
878        pos = (pos + 1) & 7;
879        if (s->joint)
880            L += (unsigned)(R -= (unsigned)(L >> 1));
881        crc = (crc * 3 + L) * 3 + R;
882
883        if (type == AV_SAMPLE_FMT_FLTP) {
884            *dstfl_l++ = wv_get_value_float(s, &crc_extra_bits, L);
885            *dstfl_r++ = wv_get_value_float(s, &crc_extra_bits, R);
886        } else if (type == AV_SAMPLE_FMT_S32P) {
887            *dst32_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
888            *dst32_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
889        } else {
890            *dst16_l++ = wv_get_value_integer(s, &crc_extra_bits, L);
891            *dst16_r++ = wv_get_value_integer(s, &crc_extra_bits, R);
892        }
893        count++;
894    } while (!last && count < s->samples);
895
896    if (last && count < s->samples) {
897        int size = av_get_bytes_per_sample(type);
898        memset((uint8_t*)dst_l + count*size, 0, (s->samples-count)*size);
899        memset((uint8_t*)dst_r + count*size, 0, (s->samples-count)*size);
900    }
901
902    if ((s->avctx->err_recognition & AV_EF_CRCCHECK) &&
903        wv_check_crc(s, crc, crc_extra_bits))
904        return AVERROR_INVALIDDATA;
905
906    return 0;
907}
908
909static inline int wv_unpack_mono(WavpackFrameContext *s, GetBitContext *gb,
910                                 void *dst, const int type)
911{
912    int i, j, count = 0;
913    int last, t;
914    int A, S, T;
915    int pos                  = 0;
916    uint32_t crc             = 0xFFFFFFFF;
917    uint32_t crc_extra_bits  = 0xFFFFFFFF;
918    int16_t *dst16           = dst;
919    int32_t *dst32           = dst;
920    float *dstfl             = dst;
921
922    s->one = s->zero = s->zeroes = 0;
923    do {
924        T = wv_get_value(s, gb, 0, &last);
925        S = 0;
926        if (last)
927            break;
928        for (i = 0; i < s->terms; i++) {
929            t = s->decorr[i].value;
930            if (t > 8) {
931                if (t & 1)
932                    A =  2U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1];
933                else
934                    A = (int)(3U * s->decorr[i].samplesA[0] - s->decorr[i].samplesA[1]) >> 1;
935                s->decorr[i].samplesA[1] = s->decorr[i].samplesA[0];
936                j                        = 0;
937            } else {
938                A = s->decorr[i].samplesA[pos];
939                j = (pos + t) & 7;
940            }
941            if (type != AV_SAMPLE_FMT_S16P)
942                S = T + ((s->decorr[i].weightA * (int64_t)A + 512) >> 10);
943            else
944                S = T + (unsigned)((int)(s->decorr[i].weightA * (unsigned)A + 512) >> 10);
945            if (A && T)
946                s->decorr[i].weightA -= ((((T ^ A) >> 30) & 2) - 1) * s->decorr[i].delta;
947            s->decorr[i].samplesA[j] = T = S;
948        }
949        pos = (pos + 1) & 7;
950        crc = crc * 3 + S;
951
952        if (type == AV_SAMPLE_FMT_FLTP) {
953            *dstfl++ = wv_get_value_float(s, &crc_extra_bits, S);
954        } else if (type == AV_SAMPLE_FMT_S32P) {
955            *dst32++ = wv_get_value_integer(s, &crc_extra_bits, S);
956        } else {
957            *dst16++ = wv_get_value_integer(s, &crc_extra_bits, S);
958        }
959        count++;
960    } while (!last && count < s->samples);
961
962    if (last && count < s->samples) {
963        int size = av_get_bytes_per_sample(type);
964        memset((uint8_t*)dst + count*size, 0, (s->samples-count)*size);
965    }
966
967    if (s->avctx->err_recognition & AV_EF_CRCCHECK) {
968        int ret = wv_check_crc(s, crc, crc_extra_bits);
969        if (ret < 0 && s->avctx->err_recognition & AV_EF_EXPLODE)
970            return ret;
971    }
972
973    return 0;
974}
975
976static av_cold int wv_alloc_frame_context(WavpackContext *c)
977{
978    if (c->fdec_num == WV_MAX_FRAME_DECODERS)
979        return -1;
980
981    c->fdec[c->fdec_num] = av_mallocz(sizeof(**c->fdec));
982    if (!c->fdec[c->fdec_num])
983        return -1;
984    c->fdec_num++;
985    c->fdec[c->fdec_num - 1]->avctx = c->avctx;
986
987    return 0;
988}
989
990static int wv_dsd_reset(WavpackContext *s, int channels)
991{
992    int i;
993
994    s->dsdctx = NULL;
995    s->dsd_channels = 0;
996    av_buffer_unref(&s->dsd_ref);
997
998    if (!channels)
999        return 0;
1000
1001    if (channels > INT_MAX / sizeof(*s->dsdctx))
1002        return AVERROR(EINVAL);
1003
1004    s->dsd_ref = av_buffer_allocz(channels * sizeof(*s->dsdctx));
1005    if (!s->dsd_ref)
1006        return AVERROR(ENOMEM);
1007    s->dsdctx = (DSDContext*)s->dsd_ref->data;
1008    s->dsd_channels = channels;
1009
1010    for (i = 0; i < channels; i++)
1011        memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
1012
1013    return 0;
1014}
1015
1016#if HAVE_THREADS
1017static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
1018{
1019    WavpackContext *fsrc = src->priv_data;
1020    WavpackContext *fdst = dst->priv_data;
1021    int ret;
1022
1023    if (dst == src)
1024        return 0;
1025
1026    ff_thread_release_ext_buffer(dst, &fdst->curr_frame);
1027    if (fsrc->curr_frame.f->data[0]) {
1028        if ((ret = ff_thread_ref_frame(&fdst->curr_frame, &fsrc->curr_frame)) < 0)
1029            return ret;
1030    }
1031
1032    fdst->dsdctx = NULL;
1033    fdst->dsd_channels = 0;
1034    ret = av_buffer_replace(&fdst->dsd_ref, fsrc->dsd_ref);
1035    if (ret < 0)
1036        return ret;
1037    if (fsrc->dsd_ref) {
1038        fdst->dsdctx = (DSDContext*)fdst->dsd_ref->data;
1039        fdst->dsd_channels = fsrc->dsd_channels;
1040    }
1041
1042    return 0;
1043}
1044#endif
1045
1046static av_cold int wavpack_decode_init(AVCodecContext *avctx)
1047{
1048    WavpackContext *s = avctx->priv_data;
1049
1050    s->avctx = avctx;
1051
1052    s->fdec_num = 0;
1053
1054    s->curr_frame.f = av_frame_alloc();
1055    s->prev_frame.f = av_frame_alloc();
1056
1057    if (!s->curr_frame.f || !s->prev_frame.f)
1058        return AVERROR(ENOMEM);
1059
1060    ff_init_dsd_data();
1061
1062    return 0;
1063}
1064
1065static av_cold int wavpack_decode_end(AVCodecContext *avctx)
1066{
1067    WavpackContext *s = avctx->priv_data;
1068
1069    for (int i = 0; i < s->fdec_num; i++)
1070        av_freep(&s->fdec[i]);
1071    s->fdec_num = 0;
1072
1073    ff_thread_release_ext_buffer(avctx, &s->curr_frame);
1074    av_frame_free(&s->curr_frame.f);
1075
1076    ff_thread_release_ext_buffer(avctx, &s->prev_frame);
1077    av_frame_free(&s->prev_frame.f);
1078
1079    av_buffer_unref(&s->dsd_ref);
1080
1081    return 0;
1082}
1083
/**
 * Decode a single WavPack block and write its channel(s) into the current
 * output frame.
 *
 * The fixed header fields (sample count, flags, CRC) are read first, then
 * the metadata sub-blocks are parsed, and finally the samples are unpacked
 * with the mono/stereo PCM or DSD unpackers.  For the first block of a
 * packet (wc->ch_offset == 0) the stream parameters on avctx are updated
 * and the output buffer is obtained.
 *
 * @param avctx    codec context
 * @param block_no index of the block inside the packet; selects
 *                 wc->fdec[block_no], allocating a context if needed
 * @param buf      block data, starting at the sample-count field
 * @param buf_size number of bytes available in buf
 * @return 0 on success (also when surplus channels are ignored without
 *         AV_EF_EXPLODE), a negative AVERROR code on failure
 */
static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
                                const uint8_t *buf, int buf_size)
{
    WavpackContext *wc = avctx->priv_data;
    WavpackFrameContext *s;
    GetByteContext gb;
    enum AVSampleFormat sample_fmt;
    void *samples_l = NULL, *samples_r = NULL;
    int ret;
    int got_terms   = 0, got_weights = 0, got_samples = 0,
        got_entropy = 0, got_pcm     = 0, got_float   = 0, got_hybrid = 0;
    int got_dsd = 0;
    int i, j, id, size, ssize, weights, t;
    int bpp, chan = 0, orig_bpp, sample_rate = 0, rate_x = 1, dsd_mode = 0;
    int multiblock;
    uint64_t chmask = 0;

    if (block_no >= wc->fdec_num && wv_alloc_frame_context(wc) < 0) {
        av_log(avctx, AV_LOG_ERROR, "Error creating frame decode context\n");
        return AVERROR_INVALIDDATA;
    }

    s = wc->fdec[block_no];
    if (!s) {
        av_log(avctx, AV_LOG_ERROR, "Context for block %d is not present\n",
               block_no);
        return AVERROR_INVALIDDATA;
    }

    /* start every block from a clean per-block state */
    memset(s->decorr, 0, MAX_TERMS * sizeof(Decorr));
    memset(s->ch, 0, sizeof(s->ch));
    s->extra_bits     = 0;
    s->and            = s->or = s->shift = 0;
    s->got_extra_bits = 0;

    bytestream2_init(&gb, buf, buf_size);

    /* all blocks of a packet must carry the sample count found by the caller */
    s->samples = bytestream2_get_le32(&gb);
    if (s->samples != wc->samples) {
        av_log(avctx, AV_LOG_ERROR, "Mismatching number of samples in "
               "a sequence: %d and %d\n", wc->samples, s->samples);
        return AVERROR_INVALIDDATA;
    }
    s->frame_flags = bytestream2_get_le32(&gb);

    /* float and DSD data are output as float; otherwise the low two flag
     * bits select 16-bit or 32-bit integer output */
    if (s->frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA))
        sample_fmt = AV_SAMPLE_FMT_FLTP;
    else if ((s->frame_flags & 0x03) <= 1)
        sample_fmt = AV_SAMPLE_FMT_S16P;
    else
        sample_fmt          = AV_SAMPLE_FMT_S32P;

    /* later blocks of a packet may not switch the sample format */
    if (wc->ch_offset && avctx->sample_fmt != sample_fmt)
        return AVERROR_INVALIDDATA;

    bpp            = av_get_bytes_per_sample(sample_fmt);
    orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
    multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;

    s->stereo         = !(s->frame_flags & WV_MONO);
    s->stereo_in      =  (s->frame_flags & WV_FALSE_STEREO) ? 0 : s->stereo;
    s->joint          =   s->frame_flags & WV_JOINT_STEREO;
    s->hybrid         =   s->frame_flags & WV_HYBRID_MODE;
    s->hybrid_bitrate =   s->frame_flags & WV_HYBRID_BITRATE;
    s->post_shift     = bpp * 8 - orig_bpp + ((s->frame_flags >> 13) & 0x1f);
    if (s->post_shift < 0 || s->post_shift > 31) {
        return AVERROR_INVALIDDATA;
    }
    s->hybrid_maxclip =  ((1LL << (orig_bpp - 1)) - 1);
    s->hybrid_minclip = ((-1UL << (orig_bpp - 1)));
    s->CRC            = bytestream2_get_le32(&gb);

    // parse metadata blocks
    while (bytestream2_get_bytes_left(&gb)) {
        id   = bytestream2_get_byte(&gb);
        size = bytestream2_get_byte(&gb);
        if (id & WP_IDF_LONG)
            size |= (bytestream2_get_le16u(&gb)) << 8;
        size <<= 1; // size is specified in words
        /* ssize keeps the stored (even) length; size becomes the payload
         * length when the odd flag marks a trailing padding byte */
        ssize  = size;
        if (id & WP_IDF_ODD)
            size--;
        if (size < 0) {
            av_log(avctx, AV_LOG_ERROR,
                   "Got incorrect block %02X with size %i\n", id, size);
            break;
        }
        if (bytestream2_get_bytes_left(&gb) < ssize) {
            av_log(avctx, AV_LOG_ERROR,
                   "Block size %i is out of bounds\n", size);
            break;
        }
        switch (id & WP_IDF_MASK) {
        case WP_ID_DECTERMS:
            if (size > MAX_TERMS) {
                av_log(avctx, AV_LOG_ERROR, "Too many decorrelation terms\n");
                s->terms = 0;
                bytestream2_skip(&gb, ssize);
                continue;
            }
            s->terms = size;
            /* one byte per term, stored into decorr[] back to front:
             * low 5 bits (minus 5) are the term value, top 3 bits the delta */
            for (i = 0; i < s->terms; i++) {
                uint8_t val = bytestream2_get_byte(&gb);
                s->decorr[s->terms - i - 1].value = (val & 0x1F) - 5;
                s->decorr[s->terms - i - 1].delta =  val >> 5;
            }
            got_terms = 1;
            break;
        case WP_ID_DECWEIGHTS:
            if (!got_terms) {
                av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
                /* NOTE(review): the payload is not skipped on this path —
                 * confirm this is intended */
                continue;
            }
            weights = size >> s->stereo_in;
            if (weights > MAX_TERMS || weights > s->terms) {
                av_log(avctx, AV_LOG_ERROR, "Too many decorrelation weights\n");
                bytestream2_skip(&gb, ssize);
                continue;
            }
            /* signed byte per weight, scaled by 8; positive weights get an
             * additional rounding correction */
            for (i = 0; i < weights; i++) {
                t = (int8_t)bytestream2_get_byte(&gb);
                s->decorr[s->terms - i - 1].weightA = t * (1 << 3);
                if (s->decorr[s->terms - i - 1].weightA > 0)
                    s->decorr[s->terms - i - 1].weightA +=
                        (s->decorr[s->terms - i - 1].weightA + 64) >> 7;
                if (s->stereo_in) {
                    t = (int8_t)bytestream2_get_byte(&gb);
                    s->decorr[s->terms - i - 1].weightB = t * (1 << 3);
                    if (s->decorr[s->terms - i - 1].weightB > 0)
                        s->decorr[s->terms - i - 1].weightB +=
                            (s->decorr[s->terms - i - 1].weightB + 64) >> 7;
                }
            }
            got_weights = 1;
            break;
        case WP_ID_DECSAMPLES:
            if (!got_terms) {
                av_log(avctx, AV_LOG_ERROR, "No decorrelation terms met\n");
                /* NOTE(review): the payload is not skipped on this path —
                 * confirm this is intended */
                continue;
            }
            /* t counts the payload bytes consumed so far */
            t = 0;
            for (i = s->terms - 1; (i >= 0) && (t < size); i--) {
                if (s->decorr[i].value > 8) {
                    /* two history samples per channel */
                    s->decorr[i].samplesA[0] =
                        wp_exp2(bytestream2_get_le16(&gb));
                    s->decorr[i].samplesA[1] =
                        wp_exp2(bytestream2_get_le16(&gb));

                    if (s->stereo_in) {
                        s->decorr[i].samplesB[0] =
                            wp_exp2(bytestream2_get_le16(&gb));
                        s->decorr[i].samplesB[1] =
                            wp_exp2(bytestream2_get_le16(&gb));
                        t                       += 4;
                    }
                    t += 4;
                } else if (s->decorr[i].value < 0) {
                    /* negative terms: one sample for A and one for B */
                    s->decorr[i].samplesA[0] =
                        wp_exp2(bytestream2_get_le16(&gb));
                    s->decorr[i].samplesB[0] =
                        wp_exp2(bytestream2_get_le16(&gb));
                    t                       += 4;
                } else {
                    /* terms 0..8: 'value' history samples per channel */
                    for (j = 0; j < s->decorr[i].value; j++) {
                        s->decorr[i].samplesA[j] =
                            wp_exp2(bytestream2_get_le16(&gb));
                        if (s->stereo_in) {
                            s->decorr[i].samplesB[j] =
                                wp_exp2(bytestream2_get_le16(&gb));
                        }
                    }
                    t += s->decorr[i].value * 2 * (s->stereo_in + 1);
                }
            }
            got_samples = 1;
            break;
        case WP_ID_ENTROPY:
            /* three 16-bit medians per coded channel */
            if (size != 6 * (s->stereo_in + 1)) {
                av_log(avctx, AV_LOG_ERROR,
                       "Entropy vars size should be %i, got %i.\n",
                       6 * (s->stereo_in + 1), size);
                bytestream2_skip(&gb, ssize);
                continue;
            }
            for (j = 0; j <= s->stereo_in; j++)
                for (i = 0; i < 3; i++) {
                    s->ch[j].median[i] = wp_exp2(bytestream2_get_le16(&gb));
                }
            got_entropy = 1;
            break;
        case WP_ID_HYBRID:
            /* size is decremented as fields are read so that the optional
             * bitrate deltas are only read when payload remains */
            if (s->hybrid_bitrate) {
                for (i = 0; i <= s->stereo_in; i++) {
                    s->ch[i].slow_level = wp_exp2(bytestream2_get_le16(&gb));
                    size               -= 2;
                }
            }
            for (i = 0; i < (s->stereo_in + 1); i++) {
                s->ch[i].bitrate_acc = bytestream2_get_le16(&gb) << 16;
                size                -= 2;
            }
            if (size > 0) {
                for (i = 0; i < (s->stereo_in + 1); i++) {
                    s->ch[i].bitrate_delta =
                        wp_exp2((int16_t)bytestream2_get_le16(&gb));
                }
            } else {
                for (i = 0; i < (s->stereo_in + 1); i++)
                    s->ch[i].bitrate_delta = 0;
            }
            got_hybrid = 1;
            break;
        case WP_ID_INT32INFO: {
            uint8_t val[4];
            if (size != 4) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid INT32INFO, size = %i\n",
                       size);
                /* NOTE(review): skips ssize - 4 where the other size-error
                 * paths skip ssize — confirm this is intended */
                bytestream2_skip(&gb, ssize - 4);
                continue;
            }
            bytestream2_get_buffer(&gb, val, 4);
            if (val[0] > 30) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid INT32INFO, extra_bits = %d (> 30)\n", val[0]);
                continue;
            } else {
                s->extra_bits = val[0];
            }
            /* val[1..3] each carry a shift count; a later non-zero byte
             * overrides the shift set by an earlier one */
            if (val[1])
                s->shift = val[1];
            if (val[2]) {
                s->and   = s->or = 1;
                s->shift = val[2];
            }
            if (val[3]) {
                s->and   = 1;
                s->shift = val[3];
            }
            if (s->shift > 31) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid INT32INFO, shift = %d (> 31)\n", s->shift);
                s->and = s->or = s->shift = 0;
                continue;
            }
            /* original WavPack decoder forces 32-bit lossy sound to be treated
             * as 24-bit one in order to have proper clipping */
            if (s->hybrid && bpp == 4 && s->post_shift < 8 && s->shift > 8) {
                s->post_shift      += 8;
                s->shift           -= 8;
                s->hybrid_maxclip >>= 8;
                s->hybrid_minclip >>= 8;
            }
            break;
        }
        case WP_ID_FLOATINFO:
            if (size != 4) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid FLOATINFO, size = %i\n", size);
                bytestream2_skip(&gb, ssize);
                continue;
            }
            s->float_flag    = bytestream2_get_byte(&gb);
            s->float_shift   = bytestream2_get_byte(&gb);
            s->float_max_exp = bytestream2_get_byte(&gb);
            if (s->float_shift > 31) {
                av_log(avctx, AV_LOG_ERROR,
                       "Invalid FLOATINFO, shift = %d (> 31)\n", s->float_shift);
                s->float_shift = 0;
                continue;
            }
            got_float        = 1;
            /* the fourth payload byte is not used */
            bytestream2_skip(&gb, 1);
            break;
        case WP_ID_DATA:
            /* packed PCM bitstream: only set up the bit reader here, the
             * samples are unpacked after all sub-blocks have been parsed */
            if ((ret = init_get_bits8(&s->gb, gb.buffer, size)) < 0)
                return ret;
            bytestream2_skip(&gb, size);
            got_pcm      = 1;
            break;
        case WP_ID_DSD_DATA:
            if (size < 2) {
                av_log(avctx, AV_LOG_ERROR, "Invalid DSD_DATA, size = %i\n",
                       size);
                bytestream2_skip(&gb, ssize);
                continue;
            }
            /* first byte: log2 of the sample-rate multiplier */
            rate_x = bytestream2_get_byte(&gb);
            if (rate_x > 30)
                return AVERROR_INVALIDDATA;
            rate_x = 1 << rate_x;
            /* second byte: DSD coding mode, only 0, 1 and 3 are accepted */
            dsd_mode = bytestream2_get_byte(&gb);
            if (dsd_mode && dsd_mode != 1 && dsd_mode != 3) {
                av_log(avctx, AV_LOG_ERROR, "Invalid DSD encoding mode: %d\n",
                    dsd_mode);
                return AVERROR_INVALIDDATA;
            }
            bytestream2_init(&s->gbyte, gb.buffer, size-2);
            bytestream2_skip(&gb, size-2);
            got_dsd      = 1;
            break;
        case WP_ID_EXTRABITS:
            if (size <= 4) {
                av_log(avctx, AV_LOG_ERROR, "Invalid EXTRABITS, size = %i\n",
                       size);
                /* NOTE(review): skips size (not ssize) before continuing,
                 * so a padding byte would not be consumed — confirm */
                bytestream2_skip(&gb, size);
                continue;
            }
            if ((ret = init_get_bits8(&s->gb_extra_bits, gb.buffer, size)) < 0)
                return ret;
            /* the payload starts with a 32-bit CRC for the extra bits */
            s->crc_extra_bits  = get_bits_long(&s->gb_extra_bits, 32);
            bytestream2_skip(&gb, size);
            s->got_extra_bits  = 1;
            break;
        case WP_ID_CHANINFO:
            if (size <= 1) {
                av_log(avctx, AV_LOG_ERROR,
                       "Insufficient channel information\n");
                return AVERROR_INVALIDDATA;
            }
            chan = bytestream2_get_byte(&gb);
            /* the payload length selects the layout of the remaining bytes */
            switch (size - 2) {
            case 0:
                chmask = bytestream2_get_byte(&gb);
                break;
            case 1:
                chmask = bytestream2_get_le16(&gb);
                break;
            case 2:
                chmask = bytestream2_get_le24(&gb);
                break;
            case 3:
                chmask = bytestream2_get_le32(&gb);
                break;
            case 4:
                /* one byte is read and discarded (size is only a scratch
                 * variable here), then 4 more bits extend the channel count */
                size = bytestream2_get_byte(&gb);
                chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
                chan  += 1;
                if (avctx->ch_layout.nb_channels != chan)
                    av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
                           " instead of %i.\n", chan, avctx->ch_layout.nb_channels);
                chmask = bytestream2_get_le24(&gb);
                break;
            case 5:
                /* same as case 4 but with a 32-bit channel mask */
                size = bytestream2_get_byte(&gb);
                chan  |= (bytestream2_get_byte(&gb) & 0xF) << 8;
                chan  += 1;
                if (avctx->ch_layout.nb_channels != chan)
                    av_log(avctx, AV_LOG_WARNING, "%i channels signalled"
                           " instead of %i.\n", chan, avctx->ch_layout.nb_channels);
                chmask = bytestream2_get_le32(&gb);
                break;
            default:
                /* NOTE(review): the remaining payload bytes are not
                 * consumed on this path — confirm this is intended */
                av_log(avctx, AV_LOG_ERROR, "Invalid channel info size %d\n",
                       size);
            }
            break;
        case WP_ID_SAMPLE_RATE:
            if (size != 3) {
                av_log(avctx, AV_LOG_ERROR, "Invalid custom sample rate.\n");
                return AVERROR_INVALIDDATA;
            }
            sample_rate = bytestream2_get_le24(&gb);
            break;
        default:
            /* unhandled sub-block id: skip its payload */
            bytestream2_skip(&gb, size);
        }
        /* consume the padding byte of odd-sized sub-blocks */
        if (id & WP_IDF_ODD)
            bytestream2_skip(&gb, 1);
    }

    /* PCM data needs the full set of decoding parameters */
    if (got_pcm) {
        if (!got_terms) {
            av_log(avctx, AV_LOG_ERROR, "No block with decorrelation terms\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_weights) {
            av_log(avctx, AV_LOG_ERROR, "No block with decorrelation weights\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_samples) {
            av_log(avctx, AV_LOG_ERROR, "No block with decorrelation samples\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_entropy) {
            av_log(avctx, AV_LOG_ERROR, "No block with entropy info\n");
            return AVERROR_INVALIDDATA;
        }
        if (s->hybrid && !got_hybrid) {
            av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
            return AVERROR_INVALIDDATA;
        }
        if (!got_float && sample_fmt == AV_SAMPLE_FMT_FLTP) {
            av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
            return AVERROR_INVALIDDATA;
        }
        /* for integer output, ignore the extra-bits stream if it cannot
         * hold extra_bits bits for every sample of every coded channel */
        if (s->got_extra_bits && sample_fmt != AV_SAMPLE_FMT_FLTP) {
            const int size   = get_bits_left(&s->gb_extra_bits);
            const int wanted = s->samples * s->extra_bits << s->stereo_in;
            if (size < wanted) {
                av_log(avctx, AV_LOG_ERROR, "Too small EXTRABITS\n");
                s->got_extra_bits = 0;
            }
        }
    }

    if (!got_pcm && !got_dsd) {
        av_log(avctx, AV_LOG_ERROR, "Packed samples not found\n");
        return AVERROR_INVALIDDATA;
    }

    /* the payload type must agree with the modulation chosen for the packet */
    if ((got_pcm && wc->modulation != MODULATION_PCM) ||
        (got_dsd && wc->modulation != MODULATION_DSD)) {
            av_log(avctx, AV_LOG_ERROR, "Invalid PCM/DSD mix encountered\n");
            return AVERROR_INVALIDDATA;
    }

    /* first block of the packet: update the stream parameters and get the
     * output buffer that all blocks of this packet write into */
    if (!wc->ch_offset) {
        AVChannelLayout new_ch_layout = { 0 };
        int new_samplerate;
        /* 0xf in the rate field means the rate comes from a SAMPLE_RATE
         * sub-block instead of the wv_rates table */
        int sr = (s->frame_flags >> 23) & 0xf;
        if (sr == 0xf) {
            if (!sample_rate) {
                av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
                return AVERROR_INVALIDDATA;
            }
            new_samplerate = sample_rate;
        } else
            new_samplerate = wv_rates[sr];

        /* do the overflow check in 64 bits before scaling by rate_x */
        if (new_samplerate * (uint64_t)rate_x > INT_MAX)
            return AVERROR_INVALIDDATA;
        new_samplerate *= rate_x;

        if (multiblock) {
            if (chmask) {
                av_channel_layout_from_mask(&new_ch_layout, chmask);
                if (chan && new_ch_layout.nb_channels != chan) {
                    av_log(avctx, AV_LOG_ERROR, "Channel mask does not match the channel count\n");
                    return AVERROR_INVALIDDATA;
                }
            } else {
                /* no mask signalled: keep the layout already set on avctx */
                ret = av_channel_layout_copy(&new_ch_layout, &avctx->ch_layout);
                if (ret < 0) {
                    av_log(avctx, AV_LOG_ERROR, "Error copying channel layout\n");
                    return ret;
                }
            }
        } else {
            av_channel_layout_default(&new_ch_layout, s->stereo + 1);
        }

        /* clear DSD state if stream properties change */
        if (new_ch_layout.nb_channels != wc->dsd_channels ||
            av_channel_layout_compare(&new_ch_layout, &avctx->ch_layout) ||
            new_samplerate != avctx->sample_rate    ||
            !!got_dsd      != !!wc->dsdctx) {
            ret = wv_dsd_reset(wc, got_dsd ? new_ch_layout.nb_channels : 0);
            if (ret < 0) {
                av_log(avctx, AV_LOG_ERROR, "Error reinitializing the DSD context\n");
                return ret;
            }
            ff_thread_release_ext_buffer(avctx, &wc->curr_frame);
        }
        /* NOTE(review): the return value of this copy is not checked */
        av_channel_layout_copy(&avctx->ch_layout, &new_ch_layout);
        avctx->sample_rate         = new_samplerate;
        avctx->sample_fmt          = sample_fmt;
        avctx->bits_per_raw_sample = orig_bpp;

        /* the frame that was current so far becomes the previous one */
        ff_thread_release_ext_buffer(avctx, &wc->prev_frame);
        FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);

        /* get output buffer */
        wc->curr_frame.f->nb_samples = s->samples;
        ret = ff_thread_get_ext_buffer(avctx, &wc->curr_frame,
                                       AV_GET_BUFFER_FLAG_REF);
        if (ret < 0)
            return ret;

        wc->frame = wc->curr_frame.f;
        ff_thread_finish_setup(avctx);
    }

    /* refuse blocks that would write past the last channel of the frame;
     * for later blocks this is only fatal with AV_EF_EXPLODE */
    if (wc->ch_offset + s->stereo >= avctx->ch_layout.nb_channels) {
        av_log(avctx, AV_LOG_WARNING, "Too many channels coded in a packet.\n");
        return ((avctx->err_recognition & AV_EF_EXPLODE) || !wc->ch_offset) ? AVERROR_INVALIDDATA : 0;
    }

    samples_l = wc->frame->extended_data[wc->ch_offset];
    if (s->stereo)
        samples_r = wc->frame->extended_data[wc->ch_offset + 1];

    wc->ch_offset += 1 + s->stereo;

    if (s->stereo_in) {
        if (got_dsd) {
            if (dsd_mode == 3) {
                ret = wv_unpack_dsd_high(s, samples_l, samples_r);
            } else if (dsd_mode == 1) {
                ret = wv_unpack_dsd_fast(s, samples_l, samples_r);
            } else {
                ret = wv_unpack_dsd_copy(s, samples_l, samples_r);
            }
        } else {
            ret = wv_unpack_stereo(s, &s->gb, samples_l, samples_r, avctx->sample_fmt);
        }
        if (ret < 0)
            return ret;
    } else {
        if (got_dsd) {
            if (dsd_mode == 3) {
                ret = wv_unpack_dsd_high(s, samples_l, NULL);
            } else if (dsd_mode == 1) {
                ret = wv_unpack_dsd_fast(s, samples_l, NULL);
            } else {
                ret = wv_unpack_dsd_copy(s, samples_l, NULL);
            }
        } else {
            ret = wv_unpack_mono(s, &s->gb, samples_l, avctx->sample_fmt);
        }
        if (ret < 0)
            return ret;

        /* stereo flagged but only one channel coded (stereo_in is 0):
         * duplicate the decoded channel into the second plane */
        if (s->stereo)
            memcpy(samples_r, samples_l, bpp * s->samples);
    }

    return 0;
}
1613
1614static void wavpack_decode_flush(AVCodecContext *avctx)
1615{
1616    WavpackContext *s = avctx->priv_data;
1617
1618    wv_dsd_reset(s, 0);
1619}
1620
1621static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
1622{
1623    WavpackContext *s  = avctx->priv_data;
1624    AVFrame *frame = frmptr;
1625
1626    ff_dsd2pcm_translate (&s->dsdctx [jobnr], s->samples, 0,
1627        (uint8_t *)frame->extended_data[jobnr], 4,
1628        (float *)frame->extended_data[jobnr], 1);
1629
1630    return 0;
1631}
1632
1633static int wavpack_decode_frame(AVCodecContext *avctx, AVFrame *rframe,
1634                                int *got_frame_ptr, AVPacket *avpkt)
1635{
1636    WavpackContext *s  = avctx->priv_data;
1637    const uint8_t *buf = avpkt->data;
1638    int buf_size       = avpkt->size;
1639    int frame_size, ret, frame_flags;
1640
1641    if (avpkt->size <= WV_HEADER_SIZE)
1642        return AVERROR_INVALIDDATA;
1643
1644    s->frame     = NULL;
1645    s->block     = 0;
1646    s->ch_offset = 0;
1647
1648    /* determine number of samples */
1649    s->samples  = AV_RL32(buf + 20);
1650    frame_flags = AV_RL32(buf + 24);
1651    if (s->samples <= 0 || s->samples > WV_MAX_SAMPLES) {
1652        av_log(avctx, AV_LOG_ERROR, "Invalid number of samples: %d\n",
1653               s->samples);
1654        return AVERROR_INVALIDDATA;
1655    }
1656
1657    s->modulation = (frame_flags & WV_DSD_DATA) ? MODULATION_DSD : MODULATION_PCM;
1658
1659    while (buf_size > WV_HEADER_SIZE) {
1660        frame_size = AV_RL32(buf + 4) - 12;
1661        buf       += 20;
1662        buf_size  -= 20;
1663        if (frame_size <= 0 || frame_size > buf_size) {
1664            av_log(avctx, AV_LOG_ERROR,
1665                   "Block %d has invalid size (size %d vs. %d bytes left)\n",
1666                   s->block, frame_size, buf_size);
1667            ret = AVERROR_INVALIDDATA;
1668            goto error;
1669        }
1670        if ((ret = wavpack_decode_block(avctx, s->block, buf, frame_size)) < 0)
1671            goto error;
1672        s->block++;
1673        buf      += frame_size;
1674        buf_size -= frame_size;
1675    }
1676
1677    if (s->ch_offset != avctx->ch_layout.nb_channels) {
1678        av_log(avctx, AV_LOG_ERROR, "Not enough channels coded in a packet.\n");
1679        ret = AVERROR_INVALIDDATA;
1680        goto error;
1681    }
1682
1683    ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1684    ff_thread_release_ext_buffer(avctx, &s->prev_frame);
1685
1686    if (s->modulation == MODULATION_DSD)
1687        avctx->execute2(avctx, dsd_channel, s->frame, NULL, avctx->ch_layout.nb_channels);
1688
1689    ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1690
1691    if ((ret = av_frame_ref(rframe, s->frame)) < 0)
1692        return ret;
1693
1694    *got_frame_ptr = 1;
1695
1696    return avpkt->size;
1697
1698error:
1699    if (s->frame) {
1700        ff_thread_await_progress(&s->prev_frame, INT_MAX, 0);
1701        ff_thread_release_ext_buffer(avctx, &s->prev_frame);
1702        ff_thread_report_progress(&s->curr_frame, INT_MAX, 0);
1703    }
1704
1705    return ret;
1706}
1707
1708const FFCodec ff_wavpack_decoder = {
1709    .p.name         = "wavpack",
1710    .p.long_name    = NULL_IF_CONFIG_SMALL("WavPack"),
1711    .p.type         = AVMEDIA_TYPE_AUDIO,
1712    .p.id           = AV_CODEC_ID_WAVPACK,
1713    .priv_data_size = sizeof(WavpackContext),
1714    .init           = wavpack_decode_init,
1715    .close          = wavpack_decode_end,
1716    FF_CODEC_DECODE_CB(wavpack_decode_frame),
1717    .flush          = wavpack_decode_flush,
1718    .update_thread_context = ONLY_IF_THREADS_ENABLED(update_thread_context),
1719    .p.capabilities = AV_CODEC_CAP_DR1 | AV_CODEC_CAP_FRAME_THREADS |
1720                      AV_CODEC_CAP_SLICE_THREADS | AV_CODEC_CAP_CHANNEL_CONF,
1721    .caps_internal  = FF_CODEC_CAP_INIT_THREADSAFE | FF_CODEC_CAP_INIT_CLEANUP |
1722                      FF_CODEC_CAP_ALLOCATE_PROGRESS,
1723};
1724