xref: /third_party/ffmpeg/libavcodec/mss2.c (revision cabdff1a)
1/*
2 * Microsoft Screen 2 (aka Windows Media Video V9 Screen) decoder
3 *
4 * This file is part of FFmpeg.
5 *
6 * FFmpeg is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU Lesser General Public
8 * License as published by the Free Software Foundation; either
9 * version 2.1 of the License, or (at your option) any later version.
10 *
11 * FFmpeg is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
14 * Lesser General Public License for more details.
15 *
16 * You should have received a copy of the GNU Lesser General Public
17 * License along with FFmpeg; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 */
20
21/**
22 * @file
23 * Microsoft Screen 2 (aka Windows Media Video V9 Screen) decoder
24 */
25
26#include "libavutil/avassert.h"
27#include "codec_internal.h"
28#include "error_resilience.h"
29#include "internal.h"
30#include "mpeg_er.h"
31#include "mpegvideodec.h"
32#include "msmpeg4dec.h"
33#include "qpeldsp.h"
34#include "vc1.h"
35#include "wmv2data.h"
36#include "mss12.h"
37#include "mss2dsp.h"
38
/**
 * Private context of the MSS2 decoder (avctx->priv_data).
 *
 * NOTE: v must remain the first member. decode_wmv9() and wmv9_init() read
 * avctx->priv_data directly as a VC1Context pointer.
 */
typedef struct MSS2Context {
    VC1Context     v;              ///< embedded VC-1 decoder used for WMV9-coded rectangles; must be first
    int            split_position; ///< first row of the second slice (rows [0, split_position) form the first)
    AVFrame       *last_pic;       ///< reference to the most recently decoded frame
    MSS12Context   c;              ///< decoder state shared with the MSS1/MSS2 code in mss12
    MSS2DSPContext dsp;            ///< WMV9 blit, gray fill and upsampling helpers
    SliceContext   sc[2];          ///< slice contexts for the top (sc[0]) and bottom (sc[1]) slice
} MSS2Context;
47
48static void arith2_normalise(ArithCoder *c)
49{
50    while ((c->high >> 15) - (c->low >> 15) < 2) {
51        if ((c->low ^ c->high) & 0x10000) {
52            c->high  ^= 0x8000;
53            c->value ^= 0x8000;
54            c->low   ^= 0x8000;
55        }
56        c->high  = (uint16_t)c->high  << 8  | 0xFF;
57        c->value = (uint16_t)c->value << 8  | bytestream2_get_byte(c->gbc.gB);
58        c->low   = (uint16_t)c->low   << 8;
59    }
60}
61
/* Instantiates arith2_get_bit() (used by mss2_decode_frame() to read the
 * WMV9 rectangle list) from the ARITH_GET_BIT template, which is defined
 * outside this file. */
ARITH_GET_BIT(arith2)
63
64/* L. Stuiver and A. Moffat: "Piecewise Integer Mapping for Arithmetic Coding."
65 * In Proc. 8th Data Compression Conference (DCC '98), pp. 3-12, Mar. 1998 */
66
/**
 * Piecewise mapping of a decoder offset onto the scaled symbol range.
 *
 * With split = 2 * n - range, offsets up to split are returned unchanged.
 * Offsets above split are compressed by a factor of two relative to split.
 *
 * @param value offset of the decoder value from the interval's low end
 * @param n     scaled total of the distribution being decoded
 * @param range current interval width
 * @return the mapped offset
 */
static int arith2_get_scaled_value(int value, int n, int range)
{
    const int split = (n << 1) - range;

    return value <= split ? value
                          : split + ((value - split) >> 1);
}
76
77static void arith2_rescale_interval(ArithCoder *c, int range,
78                                    int low, int high, int n)
79{
80    int split = (n << 1) - range;
81
82    if (high > split)
83        c->high = split + (high - split << 1);
84    else
85        c->high = high;
86
87    c->high += c->low - 1;
88
89    if (low > split)
90        c->low += split + (low - split << 1);
91    else
92        c->low += low;
93}
94
95static int arith2_get_number(ArithCoder *c, int n)
96{
97    int range = c->high - c->low + 1;
98    int scale = av_log2(range) - av_log2(n);
99    int val;
100
101    if (n << scale > range)
102        scale--;
103
104    n <<= scale;
105
106    val = arith2_get_scaled_value(c->value - c->low, n, range) >> scale;
107
108    arith2_rescale_interval(c, range, val << scale, (val + 1) << scale, n);
109
110    arith2_normalise(c);
111
112    return val;
113}
114
115static int arith2_get_prob(ArithCoder *c, int16_t *probs)
116{
117    int range = c->high - c->low + 1, n = *probs;
118    int scale = av_log2(range) - av_log2(n);
119    int i     = 0, val;
120
121    if (n << scale > range)
122        scale--;
123
124    n <<= scale;
125
126    val = arith2_get_scaled_value(c->value - c->low, n, range) >> scale;
127    while (probs[++i] > val) ;
128
129    arith2_rescale_interval(c, range,
130                            probs[i] << scale, probs[i - 1] << scale, n);
131
132    return i;
133}
134
/* Instantiates arith2_get_model_sym() (installed as c->get_model_sym by
 * arith2_init()) from the ARITH_GET_MODEL_SYM template, which is defined
 * outside this file. */
ARITH_GET_MODEL_SYM(arith2)
136
137static int arith2_get_consumed_bytes(ArithCoder *c)
138{
139    int diff = (c->high >> 16) - (c->low >> 16);
140    int bp   = bytestream2_tell(c->gbc.gB) - 3 << 3;
141    int bits = 1;
142
143    while (!(diff & 0x80)) {
144        bits++;
145        diff <<= 1;
146    }
147
148    return (bits + bp + 7 >> 3) + ((c->low >> 16) + 1 == c->high >> 16);
149}
150
151static void arith2_init(ArithCoder *c, GetByteContext *gB)
152{
153    c->low           = 0;
154    c->high          = 0xFFFFFF;
155    c->value         = bytestream2_get_be24(gB);
156    c->overread      = 0;
157    c->gbc.gB        = gB;
158    c->get_model_sym = arith2_get_model_sym;
159    c->get_number    = arith2_get_number;
160}
161
162static int decode_pal_v2(MSS12Context *ctx, const uint8_t *buf, int buf_size)
163{
164    int i, ncol;
165    uint32_t *pal = ctx->pal + 256 - ctx->free_colours;
166
167    if (!ctx->free_colours)
168        return 0;
169
170    ncol = *buf++;
171    if (ncol > ctx->free_colours || buf_size < 2 + ncol * 3)
172        return AVERROR_INVALIDDATA;
173    for (i = 0; i < ncol; i++)
174        *pal++ = AV_RB24(buf + 3 * i);
175
176    return 1 + ncol * 3;
177}
178
179static int decode_555(AVCodecContext *avctx, GetByteContext *gB, uint16_t *dst, ptrdiff_t stride,
180                      int keyframe, int w, int h)
181{
182    int last_symbol = 0, repeat = 0, prev_avail = 0;
183
184    if (!keyframe) {
185        int x, y, endx, endy, t;
186
187#define READ_PAIR(a, b)                 \
188    a  = bytestream2_get_byte(gB) << 4; \
189    t  = bytestream2_get_byte(gB);      \
190    a |= t >> 4;                        \
191    b  = (t & 0xF) << 8;                \
192    b |= bytestream2_get_byte(gB);      \
193
194        READ_PAIR(x, endx)
195        READ_PAIR(y, endy)
196
197        if (endx >= w || endy >= h || x > endx || y > endy)
198            return AVERROR_INVALIDDATA;
199        dst += x + stride * y;
200        w    = endx - x + 1;
201        h    = endy - y + 1;
202        if (y)
203            prev_avail = 1;
204    }
205
206    do {
207        uint16_t *p = dst;
208        do {
209            if (repeat-- < 1) {
210                int b = bytestream2_get_byte(gB);
211                if (b < 128)
212                    last_symbol = b << 8 | bytestream2_get_byte(gB);
213                else if (b > 129) {
214                    repeat = 0;
215                    while (b-- > 130) {
216                        if (repeat >= (INT_MAX >> 8) - 1) {
217                            av_log(avctx, AV_LOG_ERROR, "repeat overflow\n");
218                            return AVERROR_INVALIDDATA;
219                        }
220                        repeat = (repeat << 8) + bytestream2_get_byte(gB) + 1;
221                    }
222                    if (last_symbol == -2) {
223                        int skip = FFMIN((unsigned)repeat, dst + w - p);
224                        repeat -= skip;
225                        p      += skip;
226                    }
227                } else
228                    last_symbol = 127 - b;
229            }
230            if (last_symbol >= 0)
231                *p = last_symbol;
232            else if (last_symbol == -1 && prev_avail)
233                *p = *(p - stride);
234        } while (++p < dst + w);
235        dst       += stride;
236        prev_avail = 1;
237    } while (--h);
238
239    return 0;
240}
241
/**
 * Decode a palettised slice (keyframes) or rectangle (inter frames) coded
 * with a prefix code that is transmitted in front of the data, plus runs.
 *
 * Code construction: for each code length, a count of explicitly listed
 * symbols is read, and each gets the next code at that length. A count
 * equal to all codes still unused at the current length ends the explicit
 * list. Every remaining symbol of the alphabet then gets consecutive codes,
 * in symbol order, at the shortest lengths that still complete the code.
 *
 * Decoded symbols:
 *   0..255    palette index (written to pal_dst, expanded via pal into rgb_dst);
 *   256..267  repeat the previous symbol; the run length is coded in extra bits;
 *   268       copy the pixel from the row above (last_symbol -1);
 *   269       skip pixels, inter frames only (last_symbol -2).
 *
 * @param gb         bit reader positioned at the coded data
 * @param pal_dst    top-left of the palette-index plane
 * @param pal_stride stride of the palette-index plane
 * @param rgb_dst    top-left of the 24-bit RGB plane
 * @param rgb_stride stride of the RGB plane
 * @param pal        palette used to expand indices to RGB
 * @param keyframe   non-zero for keyframes (269-symbol alphabet, no skip)
 * @param kf_slipt   keyframe slice split row (sic: "split")
 * @param slice      keyframe slice: 0 decodes rows [0, kf_slipt), >0 the rest
 * @param w          frame width
 * @param h          frame height
 * @return 0 on success, a negative AVERROR code on failure
 */
static int decode_rle(GetBitContext *gb, uint8_t *pal_dst, ptrdiff_t pal_stride,
                      uint8_t *rgb_dst, ptrdiff_t rgb_stride, uint32_t *pal,
                      int keyframe, int kf_slipt, int slice, int w, int h)
{
    uint8_t bits[270] = { 0 };
    uint32_t codes[270];
    VLC vlc;

    int current_length = 0, read_codes = 0, next_code = 0, current_codes = 0;
    int remaining_codes, surplus_codes, i;

    /* Keyframes lack the skip symbol, hence one symbol fewer. */
    const int alphabet_size = 270 - keyframe;

    int last_symbol = 0, repeat = 0, prev_avail = 0;

    if (!keyframe) {
        /* Inter frame: only a clipped rectangle is coded. */
        int x, y, clipw, cliph;

        x     = get_bits(gb, 12);
        y     = get_bits(gb, 12);
        clipw = get_bits(gb, 12) + 1;
        cliph = get_bits(gb, 12) + 1;

        if (x + clipw > w || y + cliph > h)
            return AVERROR_INVALIDDATA;
        pal_dst += pal_stride * y + x;
        rgb_dst += rgb_stride * y + x * 3;
        w        = clipw;
        h        = cliph;
        /* A row above the rectangle exists unless it starts at row 0. */
        if (y)
            prev_avail = 1;
    } else {
        if (slice > 0) {
            pal_dst   += pal_stride * kf_slipt;
            rgb_dst   += rgb_stride * kf_slipt;
            prev_avail = 1;
            h         -= kf_slipt;
        } else
            h = kf_slipt;
    }

    /* read explicit codes */
    do {
        while (current_codes--) {
            int symbol = get_bits(gb, 8);
            /* Map the 8-bit value onto the alphabet:
             * - values below 190 are taken directly;
             * - 190..203 (190..202 on keyframes) take one extra bit;
             * - larger values are shifted past that doubled range. */
            if (symbol >= 204 - keyframe)
                symbol += 14 - keyframe;
            else if (symbol > 189)
                symbol = get_bits1(gb) + (symbol << 1) - 190;
            /* Each symbol may be listed only once. */
            if (bits[symbol])
                return AVERROR_INVALIDDATA;
            bits[symbol]  = current_length;
            codes[symbol] = next_code++;
            read_codes++;
        }
        current_length++;
        next_code     <<= 1;
        remaining_codes = (1 << current_length) - next_code;
        current_codes   = get_bits(gb, av_ceil_log2(remaining_codes + 1));
        if (current_length > 22 || current_codes > remaining_codes)
            return AVERROR_INVALIDDATA;
    } while (current_codes != remaining_codes);

    remaining_codes = alphabet_size - read_codes;

    /* determine the minimum length to fit the rest of the alphabet */
    while ((surplus_codes = (2 << current_length) -
                            (next_code << 1) - remaining_codes) < 0) {
        current_length++;
        next_code <<= 1;
    }

    /* add the rest of the symbols lexicographically */
    for (i = 0; i < alphabet_size; i++)
        if (!bits[i]) {
            /* Once the surplus short codes are used, switch to the next
             * length. */
            if (surplus_codes-- == 0) {
                current_length++;
                next_code <<= 1;
            }
            bits[i]  = current_length;
            codes[i] = next_code++;
        }

    /* The code must be complete: every code of the final length used. */
    if (next_code != 1 << current_length)
        return AVERROR_INVALIDDATA;

    if ((i = init_vlc(&vlc, 9, alphabet_size, bits, 1, 1, codes, 4, 4, 0)) < 0)
        return i;

    /* frame decode */
    do {
        uint8_t *pp = pal_dst;
        uint8_t *rp = rgb_dst;
        do {
            if (repeat-- < 1) {
                int b = get_vlc2(gb, vlc.table, 9, 3);
                if (b < 256)
                    last_symbol = b;
                else if (b < 268) {
                    /* Run of the previous symbol.
                     * - Symbols 256..266 give b = 0..10.
                     * - Symbol 267 reads b = 10..25 from 4 more bits.
                     * - The run length is (1 << b) - 1 plus b extra bits. */
                    b -= 256;
                    if (b == 11)
                        b = get_bits(gb, 4) + 10;

                    if (!b)
                        repeat = 0;
                    else
                        repeat = get_bits(gb, b);

                    repeat += (1 << b) - 1;

                    /* Skip runs jump ahead, clamped to the current row;
                     * the remainder carries over to the next row. */
                    if (last_symbol == -2) {
                        int skip = FFMIN(repeat, pal_dst + w - pp);
                        repeat -= skip;
                        pp     += skip;
                        rp     += skip * 3;
                    }
                } else
                    last_symbol = 267 - b;
            }
            if (last_symbol >= 0) {
                *pp = last_symbol;
                AV_WB24(rp, pal[last_symbol]);
            } else if (last_symbol == -1 && prev_avail) {
                *pp = *(pp - pal_stride);
                memcpy(rp, rp - rgb_stride, 3);
            }
            rp += 3;
        } while (++pp < pal_dst + w);
        pal_dst   += pal_stride;
        rgb_dst   += rgb_stride;
        prev_avail = 1;
    } while (--h);

    ff_free_vlc(&vlc);
    return 0;
}
378
/**
 * Decode one WMV9 (VC-1 I-frame) rectangle with the embedded VC-1 decoder
 * and blit the result into the RGB frame at (x, y).
 *
 * avctx->pix_fmt is switched to YUV420P so that ff_mpv_frame_start()
 * allocates a YUV picture. It is set back to RGB24 on every return path
 * after that switch.
 *
 * @param buf       VC-1 frame data
 * @param buf_size  bytes available in buf
 * @param x, y      top-left corner of the rectangle in the RGB frame
 * @param w, h      rectangle size
 * @param wmv9_mask -1 for an unmasked blit; otherwise passed to the masked
 *                  blit together with pal_pic (presumably the palette index
 *                  that selects which pixels receive WMV9 data; confirm in
 *                  mss2dsp)
 * @return 0 on success, a negative AVERROR code on failure
 */
static int decode_wmv9(AVCodecContext *avctx, const uint8_t *buf, int buf_size,
                       int x, int y, int w, int h, int wmv9_mask)
{
    MSS2Context *ctx  = avctx->priv_data;
    MSS12Context *c   = &ctx->c;
    /* Same priv_data pointer: this relies on v being MSS2Context's first member. */
    VC1Context *v     = avctx->priv_data;
    MpegEncContext *s = &v->s;
    AVFrame *f;
    int ret;

    /* Flush the embedded decoder state before each rectangle. */
    ff_mpeg_flush(avctx);

    if ((ret = init_get_bits8(&s->gb, buf, buf_size)) < 0)
        return ret;

    s->loop_filter = avctx->skip_loop_filter < AVDISCARD_ALL;

    if (ff_vc1_parse_frame_header(v, &s->gb) < 0) {
        av_log(v->s.avctx, AV_LOG_ERROR, "header error\n");
        return AVERROR_INVALIDDATA;
    }

    if (s->pict_type != AV_PICTURE_TYPE_I) {
        av_log(v->s.avctx, AV_LOG_ERROR, "expected I-frame\n");
        return AVERROR_INVALIDDATA;
    }

    avctx->pix_fmt = AV_PIX_FMT_YUV420P;

    if ((ret = ff_mpv_frame_start(s, avctx)) < 0) {
        av_log(v->s.avctx, AV_LOG_ERROR, "ff_mpv_frame_start error\n");
        avctx->pix_fmt = AV_PIX_FMT_RGB24;
        return ret;
    }

    ff_mpeg_er_frame_start(s);

    /* Restrict decoding to the macroblocks covering the rectangle. Bits 0
     * and 1 of respic halve the coded macroblock columns and rows. */
    v->end_mb_x = (w + 15) >> 4;
    s->end_mb_y = (h + 15) >> 4;
    if (v->respic & 1)
        v->end_mb_x = v->end_mb_x + 1 >> 1;
    if (v->respic & 2)
        s->end_mb_y = s->end_mb_y + 1 >> 1;

    ff_vc1_decode_blocks(v);

    /* Error concealment is only run when the whole picture was decoded. */
    if (v->end_mb_x == s->mb_width && s->end_mb_y == s->mb_height) {
        ff_er_frame_end(&s->er);
    } else {
        av_log(v->s.avctx, AV_LOG_WARNING,
               "disabling error correction due to block count mismatch %dx%d != %dx%d\n",
               v->end_mb_x, s->end_mb_y, s->mb_width, s->mb_height);
    }

    ff_mpv_frame_end(s);

    f = s->current_picture.f;

    /* Only the symmetric case (both halvings) is upsampled back. */
    if (v->respic == 3) {
        ctx->dsp.upsample_plane(f->data[0], f->linesize[0], w,      h);
        ctx->dsp.upsample_plane(f->data[1], f->linesize[1], w+1 >> 1, h+1 >> 1);
        ctx->dsp.upsample_plane(f->data[2], f->linesize[2], w+1 >> 1, h+1 >> 1);
    } else if (v->respic)
        avpriv_request_sample(v->s.avctx,
                              "Asymmetric WMV9 rectangle subsampling");

    /* The blit helpers take a single chroma stride for both planes. */
    av_assert0(f->linesize[1] == f->linesize[2]);

    if (wmv9_mask != -1)
        ctx->dsp.mss2_blit_wmv9_masked(c->rgb_pic + y * c->rgb_stride + x * 3,
                                       c->rgb_stride, wmv9_mask,
                                       c->pal_pic + y * c->pal_stride + x,
                                       c->pal_stride,
                                       f->data[0], f->linesize[0],
                                       f->data[1], f->data[2], f->linesize[1],
                                       w, h);
    else
        ctx->dsp.mss2_blit_wmv9(c->rgb_pic + y * c->rgb_stride + x * 3,
                                c->rgb_stride,
                                f->data[0], f->linesize[0],
                                f->data[1], f->data[2], f->linesize[1],
                                w, h);

    avctx->pix_fmt = AV_PIX_FMT_RGB24;

    return 0;
}
466
/** A frame rectangle covered by the WMV9 layer of an MSS2 packet, in pixels. */
struct Rectangle {
    int coded;  ///< non-zero: WMV9 data follows; zero: the rectangle is gray-filled
    int x, y;   ///< top-left corner
    int w, h;   ///< width and height
};
470
enum {
    /** Upper bound on explicitly coded WMV9 rectangles in one packet. */
    MAX_WMV9_RECTANGLES = 20,
    /** Extra bytes beyond buf_size that the arith2 byte reader may access. */
    ARITH2_PADDING      = 2,
};
473
474static int mss2_decode_frame(AVCodecContext *avctx, AVFrame *frame,
475                             int *got_frame, AVPacket *avpkt)
476{
477    const uint8_t *buf = avpkt->data;
478    int buf_size       = avpkt->size;
479    MSS2Context *ctx = avctx->priv_data;
480    MSS12Context *c  = &ctx->c;
481    GetBitContext gb;
482    GetByteContext gB;
483    ArithCoder acoder;
484
485    int keyframe, has_wmv9, has_mv, is_rle, is_555, ret;
486
487    struct Rectangle wmv9rects[MAX_WMV9_RECTANGLES], *r;
488    int used_rects = 0, i, implicit_rect = 0, av_uninit(wmv9_mask);
489
490    if ((ret = init_get_bits8(&gb, buf, buf_size)) < 0)
491        return ret;
492
493    if (keyframe = get_bits1(&gb))
494        skip_bits(&gb, 7);
495    has_wmv9 = get_bits1(&gb);
496    has_mv   = keyframe ? 0 : get_bits1(&gb);
497    is_rle   = get_bits1(&gb);
498    is_555   = is_rle && get_bits1(&gb);
499    if (c->slice_split > 0)
500        ctx->split_position = c->slice_split;
501    else if (c->slice_split < 0) {
502        if (get_bits1(&gb)) {
503            if (get_bits1(&gb)) {
504                if (get_bits1(&gb))
505                    ctx->split_position = get_bits(&gb, 16);
506                else
507                    ctx->split_position = get_bits(&gb, 12);
508            } else
509                ctx->split_position = get_bits(&gb, 8) << 4;
510        } else {
511            if (keyframe)
512                ctx->split_position = avctx->height / 2;
513        }
514    } else
515        ctx->split_position = avctx->height;
516
517    if (c->slice_split && (ctx->split_position < 1 - is_555 ||
518                           ctx->split_position > avctx->height - 1))
519        return AVERROR_INVALIDDATA;
520
521    align_get_bits(&gb);
522    buf      += get_bits_count(&gb) >> 3;
523    buf_size -= get_bits_count(&gb) >> 3;
524
525    if (buf_size < 1)
526        return AVERROR_INVALIDDATA;
527
528    if (is_555 && (has_wmv9 || has_mv || c->slice_split && ctx->split_position))
529        return AVERROR_INVALIDDATA;
530
531    avctx->pix_fmt = is_555 ? AV_PIX_FMT_RGB555 : AV_PIX_FMT_RGB24;
532    if (ctx->last_pic->format != avctx->pix_fmt)
533        av_frame_unref(ctx->last_pic);
534
535    if (has_wmv9) {
536        bytestream2_init(&gB, buf, buf_size + ARITH2_PADDING);
537        arith2_init(&acoder, &gB);
538
539        implicit_rect = !arith2_get_bit(&acoder);
540
541        while (arith2_get_bit(&acoder)) {
542            if (used_rects == MAX_WMV9_RECTANGLES)
543                return AVERROR_INVALIDDATA;
544            r = &wmv9rects[used_rects];
545            if (!used_rects)
546                r->x = arith2_get_number(&acoder, avctx->width);
547            else
548                r->x = arith2_get_number(&acoder, avctx->width -
549                                         wmv9rects[used_rects - 1].x) +
550                       wmv9rects[used_rects - 1].x;
551            r->y = arith2_get_number(&acoder, avctx->height);
552            r->w = arith2_get_number(&acoder, avctx->width  - r->x) + 1;
553            r->h = arith2_get_number(&acoder, avctx->height - r->y) + 1;
554            used_rects++;
555        }
556
557        if (implicit_rect && used_rects) {
558            av_log(avctx, AV_LOG_ERROR, "implicit_rect && used_rects > 0\n");
559            return AVERROR_INVALIDDATA;
560        }
561
562        if (implicit_rect) {
563            wmv9rects[0].x = 0;
564            wmv9rects[0].y = 0;
565            wmv9rects[0].w = avctx->width;
566            wmv9rects[0].h = avctx->height;
567
568            used_rects = 1;
569        }
570        for (i = 0; i < used_rects; i++) {
571            if (!implicit_rect && arith2_get_bit(&acoder)) {
572                av_log(avctx, AV_LOG_ERROR, "Unexpected grandchildren\n");
573                return AVERROR_INVALIDDATA;
574            }
575            if (!i) {
576                wmv9_mask = arith2_get_bit(&acoder) - 1;
577                if (!wmv9_mask)
578                    wmv9_mask = arith2_get_number(&acoder, 256);
579            }
580            wmv9rects[i].coded = arith2_get_number(&acoder, 2);
581        }
582
583        buf      += arith2_get_consumed_bytes(&acoder);
584        buf_size -= arith2_get_consumed_bytes(&acoder);
585        if (buf_size < 1)
586            return AVERROR_INVALIDDATA;
587    }
588
589    c->mvX = c->mvY = 0;
590    if (keyframe && !is_555) {
591        if ((i = decode_pal_v2(c, buf, buf_size)) < 0)
592            return AVERROR_INVALIDDATA;
593        buf      += i;
594        buf_size -= i;
595    } else if (has_mv) {
596        buf      += 4;
597        buf_size -= 4;
598        if (buf_size < 1)
599            return AVERROR_INVALIDDATA;
600        c->mvX = AV_RB16(buf - 4) - avctx->width;
601        c->mvY = AV_RB16(buf - 2) - avctx->height;
602    }
603
604    if (c->mvX < 0 || c->mvY < 0) {
605        FFSWAP(uint8_t *, c->pal_pic, c->last_pal_pic);
606
607        if ((ret = ff_get_buffer(avctx, frame, AV_GET_BUFFER_FLAG_REF)) < 0)
608            return ret;
609
610        if (ctx->last_pic->data[0]) {
611            av_assert0(frame->linesize[0] == ctx->last_pic->linesize[0]);
612            c->last_rgb_pic = ctx->last_pic->data[0] +
613                              ctx->last_pic->linesize[0] * (avctx->height - 1);
614        } else {
615            av_log(avctx, AV_LOG_ERROR, "Missing keyframe\n");
616            return AVERROR_INVALIDDATA;
617        }
618    } else {
619        if ((ret = ff_reget_buffer(avctx, ctx->last_pic, 0)) < 0)
620            return ret;
621        if ((ret = av_frame_ref(frame, ctx->last_pic)) < 0)
622            return ret;
623
624        c->last_rgb_pic = NULL;
625    }
626    c->rgb_pic    = frame->data[0] +
627                    frame->linesize[0] * (avctx->height - 1);
628    c->rgb_stride = -frame->linesize[0];
629
630    frame->key_frame = keyframe;
631    frame->pict_type = keyframe ? AV_PICTURE_TYPE_I : AV_PICTURE_TYPE_P;
632
633    if (is_555) {
634        bytestream2_init(&gB, buf, buf_size);
635
636        if (decode_555(avctx, &gB, (uint16_t *)c->rgb_pic, c->rgb_stride >> 1,
637                       keyframe, avctx->width, avctx->height))
638            return AVERROR_INVALIDDATA;
639
640        buf_size -= bytestream2_tell(&gB);
641    } else {
642        if (keyframe) {
643            c->corrupted = 0;
644            ff_mss12_slicecontext_reset(&ctx->sc[0]);
645            if (c->slice_split)
646                ff_mss12_slicecontext_reset(&ctx->sc[1]);
647        }
648        if (is_rle) {
649            if ((ret = init_get_bits8(&gb, buf, buf_size)) < 0)
650                return ret;
651            if (ret = decode_rle(&gb, c->pal_pic, c->pal_stride,
652                                 c->rgb_pic, c->rgb_stride, c->pal, keyframe,
653                                 ctx->split_position, 0,
654                                 avctx->width, avctx->height))
655                return ret;
656            align_get_bits(&gb);
657
658            if (c->slice_split)
659                if (ret = decode_rle(&gb, c->pal_pic, c->pal_stride,
660                                     c->rgb_pic, c->rgb_stride, c->pal, keyframe,
661                                     ctx->split_position, 1,
662                                     avctx->width, avctx->height))
663                    return ret;
664
665            align_get_bits(&gb);
666            buf      += get_bits_count(&gb) >> 3;
667            buf_size -= get_bits_count(&gb) >> 3;
668        } else if (!implicit_rect || wmv9_mask != -1) {
669            if (c->corrupted)
670                return AVERROR_INVALIDDATA;
671            bytestream2_init(&gB, buf, buf_size + ARITH2_PADDING);
672            arith2_init(&acoder, &gB);
673            c->keyframe = keyframe;
674            if (c->corrupted = ff_mss12_decode_rect(&ctx->sc[0], &acoder, 0, 0,
675                                                    avctx->width,
676                                                    ctx->split_position))
677                return AVERROR_INVALIDDATA;
678
679            buf      += arith2_get_consumed_bytes(&acoder);
680            buf_size -= arith2_get_consumed_bytes(&acoder);
681            if (c->slice_split) {
682                if (buf_size < 1)
683                    return AVERROR_INVALIDDATA;
684                bytestream2_init(&gB, buf, buf_size + ARITH2_PADDING);
685                arith2_init(&acoder, &gB);
686                if (c->corrupted = ff_mss12_decode_rect(&ctx->sc[1], &acoder, 0,
687                                                        ctx->split_position,
688                                                        avctx->width,
689                                                        avctx->height - ctx->split_position))
690                    return AVERROR_INVALIDDATA;
691
692                buf      += arith2_get_consumed_bytes(&acoder);
693                buf_size -= arith2_get_consumed_bytes(&acoder);
694            }
695        } else
696            memset(c->pal_pic, 0, c->pal_stride * avctx->height);
697    }
698
699    if (has_wmv9) {
700        for (i = 0; i < used_rects; i++) {
701            int x = wmv9rects[i].x;
702            int y = wmv9rects[i].y;
703            int w = wmv9rects[i].w;
704            int h = wmv9rects[i].h;
705            if (wmv9rects[i].coded) {
706                int WMV9codedFrameSize;
707                if (buf_size < 4 || !(WMV9codedFrameSize = AV_RL24(buf)))
708                    return AVERROR_INVALIDDATA;
709                if (ret = decode_wmv9(avctx, buf + 3, buf_size - 3,
710                                      x, y, w, h, wmv9_mask))
711                    return ret;
712                buf      += WMV9codedFrameSize + 3;
713                buf_size -= WMV9codedFrameSize + 3;
714            } else {
715                uint8_t *dst = c->rgb_pic + y * c->rgb_stride + x * 3;
716                if (wmv9_mask != -1) {
717                    ctx->dsp.mss2_gray_fill_masked(dst, c->rgb_stride,
718                                                   wmv9_mask,
719                                                   c->pal_pic + y * c->pal_stride + x,
720                                                   c->pal_stride,
721                                                   w, h);
722                } else {
723                    do {
724                        memset(dst, 0x80, w * 3);
725                        dst += c->rgb_stride;
726                    } while (--h);
727                }
728            }
729        }
730    }
731
732    if (buf_size)
733        av_log(avctx, AV_LOG_WARNING, "buffer not fully consumed\n");
734
735    if (c->mvX < 0 || c->mvY < 0) {
736        av_frame_unref(ctx->last_pic);
737        ret = av_frame_ref(ctx->last_pic, frame);
738        if (ret < 0)
739            return ret;
740    }
741
742    *got_frame       = 1;
743
744    return avpkt->size;
745}
746
747static av_cold int wmv9_init(AVCodecContext *avctx)
748{
749    VC1Context *v = avctx->priv_data;
750    int ret;
751
752    v->s.avctx    = avctx;
753
754    ff_vc1_init_common(v);
755
756    v->profile = PROFILE_MAIN;
757
758    v->zz_8x4     = ff_wmv2_scantableA;
759    v->zz_4x8     = ff_wmv2_scantableB;
760    v->res_y411   = 0;
761    v->res_sprite = 0;
762
763    v->frmrtq_postproc = 7;
764    v->bitrtq_postproc = 31;
765
766    v->res_x8          = 0;
767    v->multires        = 0;
768    v->res_fasttx      = 1;
769
770    v->fastuvmc        = 0;
771
772    v->extended_mv     = 0;
773
774    v->dquant          = 1;
775    v->vstransform     = 1;
776
777    v->res_transtab    = 0;
778
779    v->overlap         = 0;
780
781    v->resync_marker   = 0;
782    v->rangered        = 0;
783
784    v->s.max_b_frames = avctx->max_b_frames = 0;
785    v->quantizer_mode = 0;
786
787    v->finterpflag = 0;
788
789    v->res_rtm_flag = 1;
790
791    ff_vc1_init_transposed_scantables(v);
792
793    if ((ret = ff_msmpeg4_decode_init(avctx)) < 0 ||
794        (ret = ff_vc1_decode_init_alloc_tables(v)) < 0)
795        return ret;
796
797    /* error concealment */
798    v->s.me.qpel_put = v->s.qdsp.put_qpel_pixels_tab;
799    v->s.me.qpel_avg = v->s.qdsp.avg_qpel_pixels_tab;
800
801    return 0;
802}
803
804static av_cold int mss2_decode_end(AVCodecContext *avctx)
805{
806    MSS2Context *const ctx = avctx->priv_data;
807
808    av_frame_free(&ctx->last_pic);
809
810    ff_mss12_decode_end(&ctx->c);
811    av_freep(&ctx->c.pal_pic);
812    av_freep(&ctx->c.last_pal_pic);
813    ff_vc1_decode_end(avctx);
814
815    return 0;
816}
817
818static av_cold int mss2_decode_init(AVCodecContext *avctx)
819{
820    MSS2Context * const ctx = avctx->priv_data;
821    MSS12Context *c = &ctx->c;
822    int ret;
823    c->avctx = avctx;
824    if (ret = ff_mss12_decode_init(c, 1, &ctx->sc[0], &ctx->sc[1]))
825        return ret;
826    ctx->last_pic   = av_frame_alloc();
827    c->pal_stride   = c->mask_stride;
828    c->pal_pic      = av_mallocz(c->pal_stride * avctx->height);
829    c->last_pal_pic = av_mallocz(c->pal_stride * avctx->height);
830    if (!c->pal_pic || !c->last_pal_pic || !ctx->last_pic) {
831        mss2_decode_end(avctx);
832        return AVERROR(ENOMEM);
833    }
834    if (ret = wmv9_init(avctx)) {
835        mss2_decode_end(avctx);
836        return ret;
837    }
838    ff_mss2dsp_init(&ctx->dsp);
839
840    avctx->pix_fmt = c->free_colours == 127 ? AV_PIX_FMT_RGB555
841                                            : AV_PIX_FMT_RGB24;
842
843
844    return 0;
845}
846
847const FFCodec ff_mss2_decoder = {
848    .p.name         = "mss2",
849    .p.long_name    = NULL_IF_CONFIG_SMALL("MS Windows Media Video V9 Screen"),
850    .p.type         = AVMEDIA_TYPE_VIDEO,
851    .p.id           = AV_CODEC_ID_MSS2,
852    .priv_data_size = sizeof(MSS2Context),
853    .init           = mss2_decode_init,
854    .close          = mss2_decode_end,
855    FF_CODEC_DECODE_CB(mss2_decode_frame),
856    .p.capabilities = AV_CODEC_CAP_DR1,
857    .caps_internal  = FF_CODEC_CAP_INIT_THREADSAFE,
858};
859