1// SPDX-License-Identifier: Apache-2.0
2// ----------------------------------------------------------------------------
3// Copyright 2011-2024 Arm Limited
4//
5// Licensed under the Apache License, Version 2.0 (the "License"); you may not
6// use this file except in compliance with the License. You may obtain a copy
7// of the License at:
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14// License for the specific language governing permissions and limitations
15// under the License.
16// ----------------------------------------------------------------------------
17
18/**
19 * @brief Functions to decompress a symbolic block.
20 */
21
#include "astcenc_internal.h"

#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>
26
27/**
28 * @brief Compute the integer linear interpolation of two color endpoints.
29 *
30 * @param u8_mask       The mask for lanes using decode_unorm8 rather than decode_f16.
31 * @param color0        The endpoint0 color.
32 * @param color1        The endpoint1 color.
33 * @param weights       The interpolation weight (between 0 and 64).
34 *
35 * @return The interpolated color.
36 */
37static vint4 lerp_color_int(
38	vmask4 u8_mask,
39	vint4 color0,
40	vint4 color1,
41	vint4 weights
42) {
43	vint4 weight1 = weights;
44	vint4 weight0 = vint4(64) - weight1;
45
46	vint4 color = (color0 * weight0) + (color1 * weight1) + vint4(32);
47	color = asr<6>(color);
48
49	// For decode_unorm8 values force the codec to bit replicate. This allows the
50	// rest of the codec to assume the full 0xFFFF range for everything and ignore
51	// the decode_mode setting
52	vint4 color_u8 = asr<8>(color) * vint4(257);
53	color = select(color, color_u8, u8_mask);
54
55	return color;
56}
57
58/**
59 * @brief Convert integer color value into a float value for the decoder.
60 *
61 * @param data       The integer color value post-interpolation.
62 * @param lns_mask   If set treat lane as HDR (LNS) else LDR (unorm16).
63 *
64 * @return The float color value.
65 */
66static inline vfloat4 decode_texel(
67	vint4 data,
68	vmask4 lns_mask
69) {
70	vint4 color_lns = vint4::zero();
71	vint4 color_unorm = vint4::zero();
72
73	if (any(lns_mask))
74	{
75		color_lns = lns_to_sf16(data);
76	}
77
78	if (!all(lns_mask))
79	{
80		color_unorm = unorm16_to_sf16(data);
81	}
82
83	// Pick components and then convert to FP16
84	vint4 datai = select(color_unorm, color_lns, lns_mask);
85	return float16_to_float(datai);
86}
87
88/* See header for documentation. */
89void unpack_weights(
90	const block_size_descriptor& bsd,
91	const symbolic_compressed_block& scb,
92	const decimation_info& di,
93	bool is_dual_plane,
94	int weights_plane1[BLOCK_MAX_TEXELS],
95	int weights_plane2[BLOCK_MAX_TEXELS]
96) {
97	// Safe to overshoot as all arrays are allocated to full size
98	if (!is_dual_plane)
99	{
100		// Build full 64-entry weight lookup table
101		vint4 tab0 = vint4::load(scb.weights +  0);
102		vint4 tab1 = vint4::load(scb.weights + 16);
103		vint4 tab2 = vint4::load(scb.weights + 32);
104		vint4 tab3 = vint4::load(scb.weights + 48);
105
106		vint tab0p, tab1p, tab2p, tab3p;
107		vtable_prepare(tab0, tab1, tab2, tab3, tab0p, tab1p, tab2p, tab3p);
108
109		for (unsigned int i = 0; i < bsd.texel_count; i += ASTCENC_SIMD_WIDTH)
110		{
111			vint summed_value(8);
112			vint weight_count(di.texel_weight_count + i);
113			int max_weight_count = hmax(weight_count).lane<0>();
114
115			promise(max_weight_count > 0);
116			for (int j = 0; j < max_weight_count; j++)
117			{
118				vint texel_weights(di.texel_weights_tr[j] + i);
119				vint texel_weights_int(di.texel_weight_contribs_int_tr[j] + i);
120
121				summed_value += vtable_8bt_32bi(tab0p, tab1p, tab2p, tab3p, texel_weights) * texel_weights_int;
122			}
123
124			store(lsr<4>(summed_value), weights_plane1 + i);
125		}
126	}
127	else
128	{
129		// Build a 32-entry weight lookup table per plane
130		// Plane 1
131		vint4 tab0_plane1 = vint4::load(scb.weights +  0);
132		vint4 tab1_plane1 = vint4::load(scb.weights + 16);
133		vint tab0_plane1p, tab1_plane1p;
134		vtable_prepare(tab0_plane1, tab1_plane1, tab0_plane1p, tab1_plane1p);
135
136		// Plane 2
137		vint4 tab0_plane2 = vint4::load(scb.weights + 32);
138		vint4 tab1_plane2 = vint4::load(scb.weights + 48);
139		vint tab0_plane2p, tab1_plane2p;
140		vtable_prepare(tab0_plane2, tab1_plane2, tab0_plane2p, tab1_plane2p);
141
142		for (unsigned int i = 0; i < bsd.texel_count; i += ASTCENC_SIMD_WIDTH)
143		{
144			vint sum_plane1(8);
145			vint sum_plane2(8);
146
147			vint weight_count(di.texel_weight_count + i);
148			int max_weight_count = hmax(weight_count).lane<0>();
149
150			promise(max_weight_count > 0);
151			for (int j = 0; j < max_weight_count; j++)
152			{
153				vint texel_weights(di.texel_weights_tr[j] + i);
154				vint texel_weights_int(di.texel_weight_contribs_int_tr[j] + i);
155
156				sum_plane1 += vtable_8bt_32bi(tab0_plane1p, tab1_plane1p, texel_weights) * texel_weights_int;
157				sum_plane2 += vtable_8bt_32bi(tab0_plane2p, tab1_plane2p, texel_weights) * texel_weights_int;
158			}
159
160			store(lsr<4>(sum_plane1), weights_plane1 + i);
161			store(lsr<4>(sum_plane2), weights_plane2 + i);
162		}
163	}
164}
165
/**
 * @brief Return an FP32 NaN value for use in error colors.
 *
 * This NaN encoding will turn into 0xFFFF when converted to an FP16 NaN.
 *
 * The bit pattern is copied into the float with memcpy rather than written to
 * one union member and read back through another, because reading an inactive
 * union member is undefined behavior in C++.
 *
 * @return The float color value.
 */
static float error_color_nan()
{
	static_assert(sizeof(float) == sizeof(uint32_t),
	              "error_color_nan() requires a 32-bit float");

	const uint32_t bits = 0xFFFFE000U;

	float value;
	memcpy(&value, &bits, sizeof(value));
	return value;
}
179
180/* See header for documentation. */
181void decompress_symbolic_block(
182	astcenc_profile decode_mode,
183	const block_size_descriptor& bsd,
184	int xpos,
185	int ypos,
186	int zpos,
187	const symbolic_compressed_block& scb,
188	image_block& blk
189) {
190	blk.xpos = xpos;
191	blk.ypos = ypos;
192	blk.zpos = zpos;
193
194	blk.data_min = vfloat4::zero();
195	blk.data_mean = vfloat4::zero();
196	blk.data_max = vfloat4::zero();
197	blk.grayscale = false;
198
199	// If we detected an error-block, blow up immediately.
200	if (scb.block_type == SYM_BTYPE_ERROR)
201	{
202		for (unsigned int i = 0; i < bsd.texel_count; i++)
203		{
204			blk.data_r[i] = error_color_nan();
205			blk.data_g[i] = error_color_nan();
206			blk.data_b[i] = error_color_nan();
207			blk.data_a[i] = error_color_nan();
208			blk.rgb_lns[i] = 0;
209			blk.alpha_lns[i] = 0;
210		}
211
212		return;
213	}
214
215	if ((scb.block_type == SYM_BTYPE_CONST_F16) ||
216	    (scb.block_type == SYM_BTYPE_CONST_U16))
217	{
218		vfloat4 color;
219		uint8_t use_lns = 0;
220
221		// UNORM16 constant color block
222		if (scb.block_type == SYM_BTYPE_CONST_U16)
223		{
224			vint4 colori(scb.constant_color);
225
226			// Determine the UNORM8 rounding on the decode
227			vmask4 u8_mask = get_u8_component_mask(decode_mode, blk);
228
229			// The real decoder would just use the top 8 bits, but we rescale
230			// in to a 16-bit value that rounds correctly.
231			vint4 colori_u8 = asr<8>(colori) * 257;
232			colori = select(colori, colori_u8, u8_mask);
233
234			vint4 colorf16 = unorm16_to_sf16(colori);
235			color = float16_to_float(colorf16);
236		}
237		// FLOAT16 constant color block
238		else
239		{
240			switch (decode_mode)
241			{
242			case ASTCENC_PRF_LDR_SRGB:
243			case ASTCENC_PRF_LDR:
244				color = vfloat4(error_color_nan());
245				break;
246			case ASTCENC_PRF_HDR_RGB_LDR_A:
247			case ASTCENC_PRF_HDR:
248				// Constant-color block; unpack from FP16 to FP32.
249				color = float16_to_float(vint4(scb.constant_color));
250				use_lns = 1;
251				break;
252			}
253		}
254
255		for (unsigned int i = 0; i < bsd.texel_count; i++)
256		{
257			blk.data_r[i] = color.lane<0>();
258			blk.data_g[i] = color.lane<1>();
259			blk.data_b[i] = color.lane<2>();
260			blk.data_a[i] = color.lane<3>();
261			blk.rgb_lns[i] = use_lns;
262			blk.alpha_lns[i] = use_lns;
263		}
264
265		return;
266	}
267
268	// Get the appropriate partition-table entry
269	int partition_count = scb.partition_count;
270	const auto& pi = bsd.get_partition_info(partition_count, scb.partition_index);
271
272	// Get the appropriate block descriptors
273	const auto& bm = bsd.get_block_mode(scb.block_mode);
274	const auto& di = bsd.get_decimation_info(bm.decimation_mode);
275
276	bool is_dual_plane = static_cast<bool>(bm.is_dual_plane);
277
278	// Unquantize and undecimate the weights
279	int plane1_weights[BLOCK_MAX_TEXELS];
280	int plane2_weights[BLOCK_MAX_TEXELS];
281	unpack_weights(bsd, scb, di, is_dual_plane, plane1_weights, plane2_weights);
282
283	// Now that we have endpoint colors and weights, we can unpack texel colors
284	int plane2_component = scb.plane2_component;
285	vmask4 plane2_mask = vint4::lane_id() == vint4(plane2_component);
286
287	vmask4 u8_mask = get_u8_component_mask(decode_mode, blk);
288
289	for (int i = 0; i < partition_count; i++)
290	{
291		// Decode the color endpoints for this partition
292		vint4 ep0;
293		vint4 ep1;
294		bool rgb_lns;
295		bool a_lns;
296
297		unpack_color_endpoints(decode_mode,
298		                       scb.color_formats[i],
299		                       scb.color_values[i],
300		                       rgb_lns, a_lns,
301		                       ep0, ep1);
302
303		vmask4 lns_mask(rgb_lns, rgb_lns, rgb_lns, a_lns);
304
305		int texel_count = pi.partition_texel_count[i];
306		for (int j = 0; j < texel_count; j++)
307		{
308			int tix = pi.texels_of_partition[i][j];
309			vint4 weight = select(vint4(plane1_weights[tix]), vint4(plane2_weights[tix]), plane2_mask);
310			vint4 color = lerp_color_int(u8_mask, ep0, ep1, weight);
311			vfloat4 colorf = decode_texel(color, lns_mask);
312
313			blk.data_r[tix] = colorf.lane<0>();
314			blk.data_g[tix] = colorf.lane<1>();
315			blk.data_b[tix] = colorf.lane<2>();
316			blk.data_a[tix] = colorf.lane<3>();
317		}
318	}
319}
320
321#if !defined(ASTCENC_DECOMPRESS_ONLY)
322
323/* See header for documentation. */
324float compute_symbolic_block_difference_2plane(
325	const astcenc_config& config,
326	const block_size_descriptor& bsd,
327	const symbolic_compressed_block& scb,
328	const image_block& blk
329) {
330	// If we detected an error-block, blow up immediately.
331	if (scb.block_type == SYM_BTYPE_ERROR)
332	{
333		return ERROR_CALC_DEFAULT;
334	}
335
336	assert(scb.block_mode >= 0);
337	assert(scb.partition_count == 1);
338	assert(bsd.get_block_mode(scb.block_mode).is_dual_plane == 1);
339
340	// Get the appropriate block descriptor
341	const block_mode& bm = bsd.get_block_mode(scb.block_mode);
342	const decimation_info& di = bsd.get_decimation_info(bm.decimation_mode);
343
344	// Unquantize and undecimate the weights
345	int plane1_weights[BLOCK_MAX_TEXELS];
346	int plane2_weights[BLOCK_MAX_TEXELS];
347	unpack_weights(bsd, scb, di, true, plane1_weights, plane2_weights);
348
349	vmask4 plane2_mask = vint4::lane_id() == vint4(scb.plane2_component);
350
351	vfloat4 summa = vfloat4::zero();
352
353	// Decode the color endpoints for this partition
354	vint4 ep0;
355	vint4 ep1;
356	bool rgb_lns;
357	bool a_lns;
358
359	unpack_color_endpoints(config.profile,
360	                       scb.color_formats[0],
361	                       scb.color_values[0],
362	                       rgb_lns, a_lns,
363	                       ep0, ep1);
364
365	vmask4 u8_mask = get_u8_component_mask(config.profile, blk);
366
367	// Unpack and compute error for each texel in the partition
368	unsigned int texel_count = bsd.texel_count;
369	for (unsigned int i = 0; i < texel_count; i++)
370	{
371		vint4 weight = select(vint4(plane1_weights[i]), vint4(plane2_weights[i]), plane2_mask);
372		vint4 colori = lerp_color_int(u8_mask, ep0, ep1, weight);
373
374		vfloat4 color = int_to_float(colori);
375		vfloat4 oldColor = blk.texel(i);
376
377		// Compare error using a perceptual decode metric for RGBM textures
378		if (config.flags & ASTCENC_FLG_MAP_RGBM)
379		{
380			// Fail encodings that result in zero weight M pixels. Note that this can cause
381			// "interesting" artifacts if we reject all useful encodings - we typically get max
382			// brightness encodings instead which look just as bad. We recommend users apply a
383			// bias to their stored M value, limiting the lower value to 16 or 32 to avoid
384			// getting small M values post-quantization, but we can't prove it would never
385			// happen, especially at low bit rates ...
386			if (color.lane<3>() == 0.0f)
387			{
388				return -ERROR_CALC_DEFAULT;
389			}
390
391			// Compute error based on decoded RGBM color
392			color = vfloat4(
393				color.lane<0>() * color.lane<3>() * config.rgbm_m_scale,
394				color.lane<1>() * color.lane<3>() * config.rgbm_m_scale,
395				color.lane<2>() * color.lane<3>() * config.rgbm_m_scale,
396				1.0f
397			);
398
399			oldColor = vfloat4(
400				oldColor.lane<0>() * oldColor.lane<3>() * config.rgbm_m_scale,
401				oldColor.lane<1>() * oldColor.lane<3>() * config.rgbm_m_scale,
402				oldColor.lane<2>() * oldColor.lane<3>() * config.rgbm_m_scale,
403				1.0f
404			);
405		}
406
407		vfloat4 error = oldColor - color;
408		error = min(abs(error), 1e15f);
409		error = error * error;
410
411		summa += min(dot(error, blk.channel_weight), ERROR_CALC_DEFAULT);
412	}
413
414	return summa.lane<0>();
415}
416
417/* See header for documentation. */
418float compute_symbolic_block_difference_1plane(
419	const astcenc_config& config,
420	const block_size_descriptor& bsd,
421	const symbolic_compressed_block& scb,
422	const image_block& blk
423) {
424	assert(bsd.get_block_mode(scb.block_mode).is_dual_plane == 0);
425
426	// If we detected an error-block, blow up immediately.
427	if (scb.block_type == SYM_BTYPE_ERROR)
428	{
429		return ERROR_CALC_DEFAULT;
430	}
431
432	assert(scb.block_mode >= 0);
433
434	// Get the appropriate partition-table entry
435	unsigned int partition_count = scb.partition_count;
436	const auto& pi = bsd.get_partition_info(partition_count, scb.partition_index);
437
438	// Get the appropriate block descriptor
439	const block_mode& bm = bsd.get_block_mode(scb.block_mode);
440	const decimation_info& di = bsd.get_decimation_info(bm.decimation_mode);
441
442	// Unquantize and undecimate the weights
443	int plane1_weights[BLOCK_MAX_TEXELS];
444	unpack_weights(bsd, scb, di, false, plane1_weights, nullptr);
445
446	vmask4 u8_mask = get_u8_component_mask(config.profile, blk);
447
448	vfloat4 summa = vfloat4::zero();
449	for (unsigned int i = 0; i < partition_count; i++)
450	{
451		// Decode the color endpoints for this partition
452		vint4 ep0;
453		vint4 ep1;
454		bool rgb_lns;
455		bool a_lns;
456
457		unpack_color_endpoints(config.profile,
458		                       scb.color_formats[i],
459		                       scb.color_values[i],
460		                       rgb_lns, a_lns,
461		                       ep0, ep1);
462
463		// Unpack and compute error for each texel in the partition
464		unsigned int texel_count = pi.partition_texel_count[i];
465		for (unsigned int j = 0; j < texel_count; j++)
466		{
467			unsigned int tix = pi.texels_of_partition[i][j];
468			vint4 colori = lerp_color_int(u8_mask, ep0, ep1,
469			                              vint4(plane1_weights[tix]));
470
471			vfloat4 color = int_to_float(colori);
472			vfloat4 oldColor = blk.texel(tix);
473
474			// Compare error using a perceptual decode metric for RGBM textures
475			if (config.flags & ASTCENC_FLG_MAP_RGBM)
476			{
477				// Fail encodings that result in zero weight M pixels. Note that this can cause
478				// "interesting" artifacts if we reject all useful encodings - we typically get max
479				// brightness encodings instead which look just as bad. We recommend users apply a
480				// bias to their stored M value, limiting the lower value to 16 or 32 to avoid
481				// getting small M values post-quantization, but we can't prove it would never
482				// happen, especially at low bit rates ...
483				if (color.lane<3>() == 0.0f)
484				{
485					return -ERROR_CALC_DEFAULT;
486				}
487
488				// Compute error based on decoded RGBM color
489				color = vfloat4(
490					color.lane<0>() * color.lane<3>() * config.rgbm_m_scale,
491					color.lane<1>() * color.lane<3>() * config.rgbm_m_scale,
492					color.lane<2>() * color.lane<3>() * config.rgbm_m_scale,
493					1.0f
494				);
495
496				oldColor = vfloat4(
497					oldColor.lane<0>() * oldColor.lane<3>() * config.rgbm_m_scale,
498					oldColor.lane<1>() * oldColor.lane<3>() * config.rgbm_m_scale,
499					oldColor.lane<2>() * oldColor.lane<3>() * config.rgbm_m_scale,
500					1.0f
501				);
502			}
503
504			vfloat4 error = oldColor - color;
505			error = min(abs(error), 1e15f);
506			error = error * error;
507
508			summa += min(dot(error, blk.channel_weight), ERROR_CALC_DEFAULT);
509		}
510	}
511
512	return summa.lane<0>();
513}
514
515/* See header for documentation. */
516float compute_symbolic_block_difference_1plane_1partition(
517	const astcenc_config& config,
518	const block_size_descriptor& bsd,
519	const symbolic_compressed_block& scb,
520	const image_block& blk
521) {
522	// If we detected an error-block, blow up immediately.
523	if (scb.block_type == SYM_BTYPE_ERROR)
524	{
525		return ERROR_CALC_DEFAULT;
526	}
527
528	assert(scb.block_mode >= 0);
529	assert(bsd.get_partition_info(scb.partition_count, scb.partition_index).partition_count == 1);
530
531	// Get the appropriate block descriptor
532	const block_mode& bm = bsd.get_block_mode(scb.block_mode);
533	const decimation_info& di = bsd.get_decimation_info(bm.decimation_mode);
534
535	// Unquantize and undecimate the weights
536	ASTCENC_ALIGNAS int plane1_weights[BLOCK_MAX_TEXELS];
537	unpack_weights(bsd, scb, di, false, plane1_weights, nullptr);
538
539	// Decode the color endpoints for this partition
540	vint4 ep0;
541	vint4 ep1;
542	bool rgb_lns;
543	bool a_lns;
544
545	unpack_color_endpoints(config.profile,
546	                       scb.color_formats[0],
547	                       scb.color_values[0],
548	                       rgb_lns, a_lns,
549	                       ep0, ep1);
550
551	vmask4 u8_mask = get_u8_component_mask(config.profile, blk);
552
553	// Unpack and compute error for each texel in the partition
554	vfloatacc summav = vfloatacc::zero();
555
556	vint lane_id = vint::lane_id();
557
558	unsigned int texel_count = bsd.texel_count;
559	for (unsigned int i = 0; i < texel_count; i += ASTCENC_SIMD_WIDTH)
560	{
561		// Compute EP1 contribution
562		vint weight1 = vint::loada(plane1_weights + i);
563		vint ep1_r = vint(ep1.lane<0>()) * weight1;
564		vint ep1_g = vint(ep1.lane<1>()) * weight1;
565		vint ep1_b = vint(ep1.lane<2>()) * weight1;
566		vint ep1_a = vint(ep1.lane<3>()) * weight1;
567
568		// Compute EP0 contribution
569		vint weight0 = vint(64) - weight1;
570		vint ep0_r = vint(ep0.lane<0>()) * weight0;
571		vint ep0_g = vint(ep0.lane<1>()) * weight0;
572		vint ep0_b = vint(ep0.lane<2>()) * weight0;
573		vint ep0_a = vint(ep0.lane<3>()) * weight0;
574
575		// Combine contributions
576		vint colori_r = asr<6>(ep0_r + ep1_r + vint(32));
577		vint colori_g = asr<6>(ep0_g + ep1_g + vint(32));
578		vint colori_b = asr<6>(ep0_b + ep1_b + vint(32));
579		vint colori_a = asr<6>(ep0_a + ep1_a + vint(32));
580
581		// If using a U8 decode mode bit replicate top 8 bits
582		// so rest of codec can assume 0xFFFF max range everywhere
583		vint colori_r8 = asr<8>(colori_r) * vint(257);
584		colori_r = select(colori_r, colori_r8, vmask(u8_mask.lane<0>()));
585
586		vint colori_g8 = asr<8>(colori_g) * vint(257);
587		colori_g = select(colori_g, colori_g8, vmask(u8_mask.lane<1>()));
588
589		vint colori_b8 = asr<8>(colori_b) * vint(257);
590		colori_b = select(colori_b, colori_b8, vmask(u8_mask.lane<2>()));
591
592		vint colori_a8 = asr<8>(colori_a) * vint(257);
593		colori_a = select(colori_a, colori_a8, vmask(u8_mask.lane<3>()));
594
595		// Compute color diff
596		vfloat color_r = int_to_float(colori_r);
597		vfloat color_g = int_to_float(colori_g);
598		vfloat color_b = int_to_float(colori_b);
599		vfloat color_a = int_to_float(colori_a);
600
601		vfloat color_orig_r = loada(blk.data_r + i);
602		vfloat color_orig_g = loada(blk.data_g + i);
603		vfloat color_orig_b = loada(blk.data_b + i);
604		vfloat color_orig_a = loada(blk.data_a + i);
605
606		vfloat color_error_r = min(abs(color_orig_r - color_r), vfloat(1e15f));
607		vfloat color_error_g = min(abs(color_orig_g - color_g), vfloat(1e15f));
608		vfloat color_error_b = min(abs(color_orig_b - color_b), vfloat(1e15f));
609		vfloat color_error_a = min(abs(color_orig_a - color_a), vfloat(1e15f));
610
611		// Compute squared error metric
612		color_error_r = color_error_r * color_error_r;
613		color_error_g = color_error_g * color_error_g;
614		color_error_b = color_error_b * color_error_b;
615		color_error_a = color_error_a * color_error_a;
616
617		vfloat metric = color_error_r * blk.channel_weight.lane<0>()
618		              + color_error_g * blk.channel_weight.lane<1>()
619		              + color_error_b * blk.channel_weight.lane<2>()
620		              + color_error_a * blk.channel_weight.lane<3>();
621
622		// Mask off bad lanes
623		vmask mask = lane_id < vint(texel_count);
624		lane_id += vint(ASTCENC_SIMD_WIDTH);
625		haccumulate(summav, metric, mask);
626	}
627
628	return hadd_s(summav);
629}
630
631#endif
632