xref: /third_party/ffmpeg/libavcodec/cbs_vp9.c (revision cabdff1a)
1/*
2 * This file is part of FFmpeg.
3 *
4 * FFmpeg is free software; you can redistribute it and/or
5 * modify it under the terms of the GNU Lesser General Public
6 * License as published by the Free Software Foundation; either
7 * version 2.1 of the License, or (at your option) any later version.
8 *
9 * FFmpeg is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 * Lesser General Public License for more details.
13 *
14 * You should have received a copy of the GNU Lesser General Public
15 * License along with FFmpeg; if not, write to the Free Software
16 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 */
18
19#include "libavutil/avassert.h"
20
21#include "cbs.h"
22#include "cbs_internal.h"
23#include "cbs_vp9.h"
24
25
26static int cbs_vp9_read_s(CodedBitstreamContext *ctx, GetBitContext *gbc,
27                          int width, const char *name,
28                          const int *subscripts, int32_t *write_to)
29{
30    uint32_t magnitude;
31    int position, sign;
32    int32_t value;
33
34    if (ctx->trace_enable)
35        position = get_bits_count(gbc);
36
37    if (get_bits_left(gbc) < width + 1) {
38        av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid signed value at "
39               "%s: bitstream ended.\n", name);
40        return AVERROR_INVALIDDATA;
41    }
42
43    magnitude = get_bits(gbc, width);
44    sign      = get_bits1(gbc);
45    value     = sign ? -(int32_t)magnitude : magnitude;
46
47    if (ctx->trace_enable) {
48        char bits[33];
49        int i;
50        for (i = 0; i < width; i++)
51            bits[i] = magnitude >> (width - i - 1) & 1 ? '1' : '0';
52        bits[i] = sign ? '1' : '0';
53        bits[i + 1] = 0;
54
55        ff_cbs_trace_syntax_element(ctx, position, name, subscripts,
56                                    bits, value);
57    }
58
59    *write_to = value;
60    return 0;
61}
62
63static int cbs_vp9_write_s(CodedBitstreamContext *ctx, PutBitContext *pbc,
64                           int width, const char *name,
65                           const int *subscripts, int32_t value)
66{
67    uint32_t magnitude;
68    int sign;
69
70    if (put_bits_left(pbc) < width + 1)
71        return AVERROR(ENOSPC);
72
73    sign      = value < 0;
74    magnitude = sign ? -value : value;
75
76    if (ctx->trace_enable) {
77        char bits[33];
78        int i;
79        for (i = 0; i < width; i++)
80            bits[i] = magnitude >> (width - i - 1) & 1 ? '1' : '0';
81        bits[i] = sign ? '1' : '0';
82        bits[i + 1] = 0;
83
84        ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
85                                    name, subscripts, bits, value);
86    }
87
88    put_bits(pbc, width, magnitude);
89    put_bits(pbc, 1, sign);
90
91    return 0;
92}
93
94static int cbs_vp9_read_increment(CodedBitstreamContext *ctx, GetBitContext *gbc,
95                                  uint32_t range_min, uint32_t range_max,
96                                  const char *name, uint32_t *write_to)
97{
98    uint32_t value;
99    int position, i;
100    char bits[8];
101
102    av_assert0(range_min <= range_max && range_max - range_min < sizeof(bits) - 1);
103    if (ctx->trace_enable)
104        position = get_bits_count(gbc);
105
106    for (i = 0, value = range_min; value < range_max;) {
107        if (get_bits_left(gbc) < 1) {
108            av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid increment value at "
109                   "%s: bitstream ended.\n", name);
110            return AVERROR_INVALIDDATA;
111        }
112        if (get_bits1(gbc)) {
113            bits[i++] = '1';
114            ++value;
115        } else {
116            bits[i++] = '0';
117            break;
118        }
119    }
120
121    if (ctx->trace_enable) {
122        bits[i] = 0;
123        ff_cbs_trace_syntax_element(ctx, position, name, NULL, bits, value);
124    }
125
126    *write_to = value;
127    return 0;
128}
129
130static int cbs_vp9_write_increment(CodedBitstreamContext *ctx, PutBitContext *pbc,
131                                   uint32_t range_min, uint32_t range_max,
132                                   const char *name, uint32_t value)
133{
134    int len;
135
136    av_assert0(range_min <= range_max && range_max - range_min < 8);
137    if (value < range_min || value > range_max) {
138        av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
139               "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
140               name, value, range_min, range_max);
141        return AVERROR_INVALIDDATA;
142    }
143
144    if (value == range_max)
145        len = range_max - range_min;
146    else
147        len = value - range_min + 1;
148    if (put_bits_left(pbc) < len)
149        return AVERROR(ENOSPC);
150
151    if (ctx->trace_enable) {
152        char bits[8];
153        int i;
154        for (i = 0; i < len; i++) {
155            if (range_min + i == value)
156                bits[i] = '0';
157            else
158                bits[i] = '1';
159        }
160        bits[i] = 0;
161        ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
162                                    name, NULL, bits, value);
163    }
164
165    if (len > 0)
166        put_bits(pbc, len, (1 << len) - 1 - (value != range_max));
167
168    return 0;
169}
170
171static int cbs_vp9_read_le(CodedBitstreamContext *ctx, GetBitContext *gbc,
172                           int width, const char *name,
173                           const int *subscripts, uint32_t *write_to)
174{
175    uint32_t value;
176    int position, b;
177
178    av_assert0(width % 8 == 0);
179
180    if (ctx->trace_enable)
181        position = get_bits_count(gbc);
182
183    if (get_bits_left(gbc) < width) {
184        av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid le value at "
185               "%s: bitstream ended.\n", name);
186        return AVERROR_INVALIDDATA;
187    }
188
189    value = 0;
190    for (b = 0; b < width; b += 8)
191        value |= get_bits(gbc, 8) << b;
192
193    if (ctx->trace_enable) {
194        char bits[33];
195        int i;
196        for (b = 0; b < width; b += 8)
197            for (i = 0; i < 8; i++)
198                bits[b + i] = value >> (b + i) & 1 ? '1' : '0';
199        bits[b] = 0;
200
201        ff_cbs_trace_syntax_element(ctx, position, name, subscripts,
202                                    bits, value);
203    }
204
205    *write_to = value;
206    return 0;
207}
208
209static int cbs_vp9_write_le(CodedBitstreamContext *ctx, PutBitContext *pbc,
210                            int width, const char *name,
211                            const int *subscripts, uint32_t value)
212{
213    int b;
214
215    av_assert0(width % 8 == 0);
216
217    if (put_bits_left(pbc) < width)
218        return AVERROR(ENOSPC);
219
220    if (ctx->trace_enable) {
221        char bits[33];
222        int i;
223        for (b = 0; b < width; b += 8)
224            for (i = 0; i < 8; i++)
225                bits[b + i] = value >> (b + i) & 1 ? '1' : '0';
226        bits[b] = 0;
227
228        ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
229                                    name, subscripts, bits, value);
230    }
231
232    for (b = 0; b < width; b += 8)
233        put_bits(pbc, 8, value >> b & 0xff);
234
235    return 0;
236}
237
/* Emit a trace header for the syntax structure about to be processed. */
#define HEADER(title) do { \
        ff_cbs_trace_header(ctx, title); \
    } while (0)

/* Store the result of `expr` in the caller's local `err` and return it
 * from the enclosing function when it is negative. */
#define CHECK(expr) do { \
        err = (expr); \
        if (err < 0) \
            return err; \
    } while (0)

/* Template function names, e.g. FUNC(frame) -> cbs_vp9_read_frame.  The
 * FUNC_VP9 indirection lets READWRITE expand before token pasting. */
#define FUNC_NAME(op, codec, elem) cbs_ ## codec ## _ ## op ## _ ## elem
#define FUNC_VP9(op, elem) FUNC_NAME(op, vp9, elem)
#define FUNC(elem) FUNC_VP9(READWRITE, elem)

/* Subscript array for a traced element: entry 0 is the count, followed by
 * the indices.  A count of zero yields NULL. */
#define SUBSCRIPTS(count, ...) (count > 0 ? ((int[count + 1]){ count, __VA_ARGS__ }) : NULL)

/* Shorthands for members of `current`: unsigned (f, fs) and signed (s, ss)
 * elements, with the fs/ss forms carrying subscripts. */
#define f(w, field) \
        xf(w, field, current->field, 0, )
#define s(w, field) \
        xs(w, field, current->field, 0, )
#define fs(w, field, count, ...) \
        xf(w, field, current->field, count, __VA_ARGS__)
#define ss(w, field, count, ...) \
        xs(w, field, current->field, count, __VA_ARGS__)
262
/*
 * Reader instantiation.  With READWRITE defined as "read" and RWContext as
 * GetBitContext, each FUNC(x) in the syntax template included below
 * becomes a cbs_vp9_read_x() function.
 */
#define READ
#define READWRITE read
#define RWContext GetBitContext

/* Read an unsigned `width`-bit element via ff_cbs_read_unsigned(), passing
 * 0 and (1 << width) - 1 as bounds, and store it in `var`. */
#define xf(width, name, var, subs, ...) do { \
        uint32_t value; \
        CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
                                   SUBSCRIPTS(subs, __VA_ARGS__), \
                                   &value, 0, (1 << width) - 1)); \
        var = value; \
    } while (0)
/* Read a sign-magnitude element (cbs_vp9_read_s()) and store it in `var`. */
#define xs(width, name, var, subs, ...) do { \
        int32_t value; \
        CHECK(cbs_vp9_read_s(ctx, rw, width, #name, \
                             SUBSCRIPTS(subs, __VA_ARGS__), &value)); \
        var = value; \
    } while (0)


/* Read a unary-coded value in [min, max] into current->name. */
#define increment(name, min, max) do { \
        uint32_t value; \
        CHECK(cbs_vp9_read_increment(ctx, rw, min, max, #name, &value)); \
        current->name = value; \
    } while (0)

/* Read a little-endian multi-byte element directly into current->name. */
#define fle(width, name, subs, ...) do { \
        CHECK(cbs_vp9_read_le(ctx, rw, width, #name, \
                              SUBSCRIPTS(subs, __VA_ARGS__), &current->name)); \
    } while (0)

/* Optional delta: a 1-bit delta_coded flag followed, only when set, by a
 * 4-bit signed delta.  An absent delta is stored as 0. */
#define delta_q(name) do { \
        uint8_t delta_coded; \
        int8_t delta_q; \
        xf(1, name.delta_coded, delta_coded, 0, ); \
        if (delta_coded) \
            xs(4, name.delta_q, delta_q, 0, ); \
        else \
            delta_q = 0; \
        current->name = delta_q; \
    } while (0)

/* Optional probability: a 1-bit prob_coded flag followed, only when set,
 * by an 8-bit value.  An absent probability is stored as 255. */
#define prob(name, subs, ...) do { \
        uint8_t prob_coded; \
        uint8_t prob; \
        xf(1, name.prob_coded, prob_coded, subs, __VA_ARGS__); \
        if (prob_coded) \
            xf(8, name.prob, prob, subs, __VA_ARGS__); \
        else \
            prob = 255; \
        current->name = prob; \
    } while (0)

/* Read an element that must equal `value` (passed as both bounds to
 * ff_cbs_read_unsigned()); the value read is discarded. */
#define fixed(width, name, value) do { \
        av_unused uint32_t fixed_value; \
        CHECK(ff_cbs_read_unsigned(ctx, rw, width, #name, \
                                   0, &fixed_value, value, value)); \
    } while (0)

/* Elements absent from the bitstream are set to their inferred value. */
#define infer(name, value) do { \
        current->name = value; \
    } while (0)

/* Bit offset within the current byte; 0 when the reader is byte-aligned. */
#define byte_alignment(rw) (get_bits_count(rw) % 8)

/* Generate the reader functions from the shared syntax template. */
#include "cbs_vp9_syntax_template.c"

/* Drop the reader-specific macros so the writer can redefine them. */
#undef READ
#undef READWRITE
#undef RWContext
#undef xf
#undef xs
#undef increment
#undef fle
#undef delta_q
#undef prob
#undef fixed
#undef infer
#undef byte_alignment
341
342
/*
 * Writer instantiation.  With READWRITE defined as "write" and RWContext as
 * PutBitContext, each FUNC(x) in the syntax template included below
 * becomes a cbs_vp9_write_x() function.
 */
#define WRITE
#define READWRITE write
#define RWContext PutBitContext

/* Write `var` as an unsigned `width`-bit element via ff_cbs_write_unsigned(),
 * passing 0 and (1 << width) - 1 as bounds. */
#define xf(width, name, var, subs, ...) do { \
        CHECK(ff_cbs_write_unsigned(ctx, rw, width, #name, \
                                    SUBSCRIPTS(subs, __VA_ARGS__), \
                                    var, 0, (1 << width) - 1)); \
    } while (0)
/* Write `var` as a sign-magnitude element (cbs_vp9_write_s()). */
#define xs(width, name, var, subs, ...) do { \
        CHECK(cbs_vp9_write_s(ctx, rw, width, #name, \
                              SUBSCRIPTS(subs, __VA_ARGS__), var)); \
    } while (0)

/* Write current->name as a unary-coded value in [min, max]. */
#define increment(name, min, max) do { \
        CHECK(cbs_vp9_write_increment(ctx, rw, min, max, #name, current->name)); \
    } while (0)

/* Write current->name as a little-endian multi-byte element. */
#define fle(width, name, subs, ...) do { \
        CHECK(cbs_vp9_write_le(ctx, rw, width, #name, \
                               SUBSCRIPTS(subs, __VA_ARGS__), current->name)); \
    } while (0)

/* Inverse of the reader: delta_coded is written as 1 exactly when the
 * stored delta is nonzero, and the 4-bit signed delta follows only then. */
#define delta_q(name) do { \
        xf(1, name.delta_coded, !!current->name, 0, ); \
        if (current->name) \
            xs(4, name.delta_q, current->name, 0, ); \
    } while (0)

/* Inverse of the reader: prob_coded is written as 1 exactly when the stored
 * probability is not 255, and the 8-bit value follows only then. */
#define prob(name, subs, ...) do { \
        xf(1, name.prob_coded, current->name != 255, subs, __VA_ARGS__); \
        if (current->name != 255) \
            xf(8, name.prob, current->name, subs, __VA_ARGS__); \
    } while (0)

/* Write the required constant `value` (used as both bounds). */
#define fixed(width, name, value) do { \
        CHECK(ff_cbs_write_unsigned(ctx, rw, width, #name, \
                                    0, value, value, value)); \
    } while (0)

/* Inferred elements are not written; only warn when the stored value
 * differs from the value that would be inferred. */
#define infer(name, value) do { \
        if (current->name != (value)) { \
            av_log(ctx->log_ctx, AV_LOG_WARNING, "Warning: " \
                   "%s does not match inferred value: " \
                   "%"PRId64", but should be %"PRId64".\n", \
                   #name, (int64_t)current->name, (int64_t)(value)); \
        } \
    } while (0)

/* Bit offset within the current byte; 0 when the writer is byte-aligned. */
#define byte_alignment(rw) (put_bits_count(rw) % 8)

/* Generate the writer functions from the shared syntax template. */
#include "cbs_vp9_syntax_template.c"

/* Drop the writer-specific macros. */
#undef WRITE
#undef READWRITE
#undef RWContext
#undef xf
#undef xs
#undef increment
#undef fle
#undef delta_q
#undef prob
#undef fixed
#undef infer
#undef byte_alignment
408
409
410static int cbs_vp9_split_fragment(CodedBitstreamContext *ctx,
411                                  CodedBitstreamFragment *frag,
412                                  int header)
413{
414    uint8_t superframe_header;
415    int err;
416
417    if (frag->data_size == 0)
418        return AVERROR_INVALIDDATA;
419
420    // Last byte in the packet.
421    superframe_header = frag->data[frag->data_size - 1];
422
423    if ((superframe_header & 0xe0) == 0xc0) {
424        VP9RawSuperframeIndex sfi;
425        GetBitContext gbc;
426        size_t index_size, pos;
427        int i;
428
429        index_size = 2 + (((superframe_header & 0x18) >> 3) + 1) *
430                          ((superframe_header & 0x07) + 1);
431
432        if (index_size > frag->data_size)
433            return AVERROR_INVALIDDATA;
434
435        err = init_get_bits(&gbc, frag->data + frag->data_size - index_size,
436                            8 * index_size);
437        if (err < 0)
438            return err;
439
440        err = cbs_vp9_read_superframe_index(ctx, &gbc, &sfi);
441        if (err < 0)
442            return err;
443
444        pos = 0;
445        for (i = 0; i <= sfi.frames_in_superframe_minus_1; i++) {
446            if (pos + sfi.frame_sizes[i] + index_size > frag->data_size) {
447                av_log(ctx->log_ctx, AV_LOG_ERROR, "Frame %d too large "
448                       "in superframe: %"PRIu32" bytes.\n",
449                       i, sfi.frame_sizes[i]);
450                return AVERROR_INVALIDDATA;
451            }
452
453            err = ff_cbs_append_unit_data(frag, 0,
454                                          frag->data + pos,
455                                          sfi.frame_sizes[i],
456                                          frag->data_ref);
457            if (err < 0)
458                return err;
459
460            pos += sfi.frame_sizes[i];
461        }
462        if (pos + index_size != frag->data_size) {
463            av_log(ctx->log_ctx, AV_LOG_WARNING, "Extra padding at "
464                   "end of superframe: %"SIZE_SPECIFIER" bytes.\n",
465                   frag->data_size - (pos + index_size));
466        }
467
468        return 0;
469
470    } else {
471        err = ff_cbs_append_unit_data(frag, 0,
472                                      frag->data, frag->data_size,
473                                      frag->data_ref);
474        if (err < 0)
475            return err;
476    }
477
478    return 0;
479}
480
481static int cbs_vp9_read_unit(CodedBitstreamContext *ctx,
482                             CodedBitstreamUnit *unit)
483{
484    VP9RawFrame *frame;
485    GetBitContext gbc;
486    int err, pos;
487
488    err = init_get_bits(&gbc, unit->data, 8 * unit->data_size);
489    if (err < 0)
490        return err;
491
492    err = ff_cbs_alloc_unit_content2(ctx, unit);
493    if (err < 0)
494        return err;
495    frame = unit->content;
496
497    err = cbs_vp9_read_frame(ctx, &gbc, frame);
498    if (err < 0)
499        return err;
500
501    pos = get_bits_count(&gbc);
502    av_assert0(pos % 8 == 0);
503    pos /= 8;
504    av_assert0(pos <= unit->data_size);
505
506    if (pos == unit->data_size) {
507        // No data (e.g. a show-existing-frame frame).
508    } else {
509        frame->data_ref = av_buffer_ref(unit->data_ref);
510        if (!frame->data_ref)
511            return AVERROR(ENOMEM);
512
513        frame->data      = unit->data      + pos;
514        frame->data_size = unit->data_size - pos;
515    }
516
517    return 0;
518}
519
520static int cbs_vp9_write_unit(CodedBitstreamContext *ctx,
521                              CodedBitstreamUnit *unit,
522                              PutBitContext *pbc)
523{
524    VP9RawFrame *frame = unit->content;
525    int err;
526
527    err = cbs_vp9_write_frame(ctx, pbc, frame);
528    if (err < 0)
529        return err;
530
531    // Frame must be byte-aligned.
532    av_assert0(put_bits_count(pbc) % 8 == 0);
533
534    if (frame->data) {
535        if (frame->data_size > put_bits_left(pbc) / 8)
536            return AVERROR(ENOSPC);
537
538        flush_put_bits(pbc);
539        memcpy(put_bits_ptr(pbc), frame->data, frame->data_size);
540        skip_put_bytes(pbc, frame->data_size);
541    }
542
543    return 0;
544}
545
546static int cbs_vp9_assemble_fragment(CodedBitstreamContext *ctx,
547                                     CodedBitstreamFragment *frag)
548{
549    int err;
550
551    if (frag->nb_units == 1) {
552        // Output is just the content of the single frame.
553
554        CodedBitstreamUnit *frame = &frag->units[0];
555
556        frag->data_ref = av_buffer_ref(frame->data_ref);
557        if (!frag->data_ref)
558            return AVERROR(ENOMEM);
559
560        frag->data      = frame->data;
561        frag->data_size = frame->data_size;
562
563    } else {
564        // Build superframe out of frames.
565
566        VP9RawSuperframeIndex sfi;
567        PutBitContext pbc;
568        AVBufferRef *ref;
569        uint8_t *data;
570        size_t size, max, pos;
571        int i, size_len;
572
573        if (frag->nb_units > 8) {
574            av_log(ctx->log_ctx, AV_LOG_ERROR, "Too many frames to "
575                   "make superframe: %d.\n", frag->nb_units);
576            return AVERROR(EINVAL);
577        }
578
579        max = 0;
580        for (i = 0; i < frag->nb_units; i++)
581            if (max < frag->units[i].data_size)
582                max = frag->units[i].data_size;
583
584        if (max < 2)
585            size_len = 1;
586        else
587            size_len = av_log2(max) / 8 + 1;
588        av_assert0(size_len <= 4);
589
590        sfi.superframe_marker            = VP9_SUPERFRAME_MARKER;
591        sfi.bytes_per_framesize_minus_1  = size_len - 1;
592        sfi.frames_in_superframe_minus_1 = frag->nb_units - 1;
593
594        size = 2;
595        for (i = 0; i < frag->nb_units; i++) {
596            size += size_len + frag->units[i].data_size;
597            sfi.frame_sizes[i] = frag->units[i].data_size;
598        }
599
600        ref = av_buffer_alloc(size + AV_INPUT_BUFFER_PADDING_SIZE);
601        if (!ref)
602            return AVERROR(ENOMEM);
603        data = ref->data;
604        memset(data + size, 0, AV_INPUT_BUFFER_PADDING_SIZE);
605
606        pos = 0;
607        for (i = 0; i < frag->nb_units; i++) {
608            av_assert0(size - pos > frag->units[i].data_size);
609            memcpy(data + pos, frag->units[i].data,
610                   frag->units[i].data_size);
611            pos += frag->units[i].data_size;
612        }
613        av_assert0(size - pos == 2 + frag->nb_units * size_len);
614
615        init_put_bits(&pbc, data + pos, size - pos);
616
617        err = cbs_vp9_write_superframe_index(ctx, &pbc, &sfi);
618        if (err < 0) {
619            av_log(ctx->log_ctx, AV_LOG_ERROR, "Failed to write "
620                   "superframe index.\n");
621            av_buffer_unref(&ref);
622            return err;
623        }
624
625        av_assert0(put_bits_left(&pbc) == 0);
626        flush_put_bits(&pbc);
627
628        frag->data_ref  = ref;
629        frag->data      = data;
630        frag->data_size = size;
631    }
632
633    return 0;
634}
635
636static void cbs_vp9_flush(CodedBitstreamContext *ctx)
637{
638    CodedBitstreamVP9Context *vp9 = ctx->priv_data;
639
640    memset(vp9->ref, 0, sizeof(vp9->ref));
641}
642
643static const CodedBitstreamUnitTypeDescriptor cbs_vp9_unit_types[] = {
644    CBS_UNIT_TYPE_INTERNAL_REF(0, VP9RawFrame, data),
645    CBS_UNIT_TYPE_END_OF_LIST
646};
647
648const CodedBitstreamType ff_cbs_type_vp9 = {
649    .codec_id          = AV_CODEC_ID_VP9,
650
651    .priv_data_size    = sizeof(CodedBitstreamVP9Context),
652
653    .unit_types        = cbs_vp9_unit_types,
654
655    .split_fragment    = &cbs_vp9_split_fragment,
656    .read_unit         = &cbs_vp9_read_unit,
657    .write_unit        = &cbs_vp9_write_unit,
658
659    .flush             = &cbs_vp9_flush,
660
661    .assemble_fragment = &cbs_vp9_assemble_fragment,
662};
663