1// SPDX-License-Identifier: Apache-2.0
2// ----------------------------------------------------------------------------
3// Copyright 2011-2022 Arm Limited
4//
5// Licensed under the Apache License, Version 2.0 (the "License"); you may not
6// use this file except in compliance with the License. You may obtain a copy
7// of the License at:
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14// License for the specific language governing permissions and limitations
15// under the License.
16// ----------------------------------------------------------------------------
17
18/**
19 * @brief Functions for computing image error metrics.
20 */
21
22#include <cassert>
23#include <cstdio>
24
25#include "astcenccli_internal.h"
26
/**
 * @brief Double-precision running totals for four separate channels.
 *
 * All four totals start at zero when an accumulator is created.
 */
class error_accum4
{
public:
	/** @brief The running sum for the r channel. */
	double sum_r = 0.0;

	/** @brief The running sum for the g channel. */
	double sum_g = 0.0;

	/** @brief The running sum for the b channel. */
	double sum_b = 0.0;

	/** @brief The running sum for the a channel. */
	double sum_a = 0.0;
};
39
40/**
41 * @brief Incremental addition operator for error accumulators.
42 *
43 * @param val The accumulator to increment
44 * @param inc The increment to apply
45 *
46 * @return The updated accumulator
47 */
48static error_accum4& operator+=(
49	error_accum4 &val,
50	vfloat4 inc
51) {
52	val.sum_r += static_cast<double>(inc.lane<0>());
53	val.sum_g += static_cast<double>(inc.lane<1>());
54	val.sum_b += static_cast<double>(inc.lane<2>());
55	val.sum_a += static_cast<double>(inc.lane<3>());
56	return val;
57}
58
59/**
60 * @brief mPSNR tone-mapping operator for HDR images.
61 *
62 * @param val     The color value to tone map
63 * @param fstop   The exposure fstop; should be in range [-125, 125]
64 *
65 * @return The mapped color value in [0.0f, 255.0f] range
66 */
67static float mpsnr_operator(
68	float val,
69	int fstop
70) {
71	if32 p;
72	p.u = 0x3f800000 + (fstop << 23);  // 0x3f800000 is 1.0f
73	val *= p.f;
74	val = powf(val, (1.0f / 2.2f));
75	val *= 255.0f;
76
77	return astc::clamp(val, 0.0f, 255.0f);
78}
79
80/**
81 * @brief mPSNR difference between two values.
82 *
83 * Differences are given as "val1 - val2".
84 *
85 * @param val1       The first color value
86 * @param val2       The second color value
87 * @param fstop_lo   The low exposure fstop; should be in range [-125, 125]
88 * @param fstop_hi   The high exposure fstop; should be in range [-125, 125]
89 *
90 * @return The summed mPSNR difference across all active fstop levels
91 */
92static float mpsnr_sumdiff(
93	float val1,
94	float val2,
95	int fstop_lo,
96	int fstop_hi
97) {
98	float summa = 0.0f;
99	for (int i = fstop_lo; i <= fstop_hi; i++)
100	{
101		float mval1 = mpsnr_operator(val1, i);
102		float mval2 = mpsnr_operator(val2, i);
103		float mdiff = mval1 - mval2;
104		summa += mdiff * mdiff;
105	}
106	return summa;
107}
108
109/* See header for documentation */
110void compute_error_metrics(
111	bool compute_hdr_metrics,
112	bool compute_normal_metrics,
113	int input_components,
114	const astcenc_image* img1,
115	const astcenc_image* img2,
116	int fstop_lo,
117	int fstop_hi
118) {
119	static const int componentmasks[5] { 0x00, 0x07, 0x0C, 0x07, 0x0F };
120	int componentmask = componentmasks[input_components];
121
122	error_accum4 errorsum;
123	error_accum4 alpha_scaled_errorsum;
124	error_accum4 log_errorsum;
125	error_accum4 mpsnr_errorsum;
126	double mean_angular_errorsum = 0.0;
127	double worst_angular_errorsum = 0.0;
128
129	unsigned int dim_x = astc::min(img1->dim_x, img2->dim_x);
130	unsigned int dim_y = astc::min(img1->dim_y, img2->dim_y);
131	unsigned int dim_z = astc::min(img1->dim_z, img2->dim_z);
132
133	if (img1->dim_x != img2->dim_x ||
134	    img1->dim_y != img2->dim_y ||
135	    img1->dim_z != img2->dim_z)
136	{
137		printf("WARNING: Only intersection of images will be compared:\n"
138		       "  Image 1: %dx%dx%d\n"
139		       "  Image 2: %dx%dx%d\n",
140		       img1->dim_x, img1->dim_y, img1->dim_z,
141		       img2->dim_x, img2->dim_y, img2->dim_z);
142	}
143
144	double rgb_peak = 0.0;
145	unsigned int xsize1 = img1->dim_x;
146	unsigned int xsize2 = img2->dim_x;
147
148	for (unsigned int z = 0; z < dim_z; z++)
149	{
150		for (unsigned int y = 0; y < dim_y; y++)
151		{
152			for (unsigned int x = 0; x < dim_x; x++)
153			{
154				vfloat4 color1;
155				vfloat4 color2;
156
157				if (img1->data_type == ASTCENC_TYPE_U8)
158				{
159					uint8_t* data8 = static_cast<uint8_t*>(img1->data[z]);
160
161					color1 = vfloat4(
162					    data8[(4 * xsize1 * y) + (4 * x    )],
163					    data8[(4 * xsize1 * y) + (4 * x + 1)],
164					    data8[(4 * xsize1 * y) + (4 * x + 2)],
165					    data8[(4 * xsize1 * y) + (4 * x + 3)]);
166
167					color1 = color1 / 255.0f;
168				}
169				else if (img1->data_type == ASTCENC_TYPE_F16)
170				{
171					uint16_t* data16 = static_cast<uint16_t*>(img1->data[z]);
172
173					vint4 color1i = vint4(
174					    data16[(4 * xsize1 * y) + (4 * x    )],
175					    data16[(4 * xsize1 * y) + (4 * x + 1)],
176					    data16[(4 * xsize1 * y) + (4 * x + 2)],
177					    data16[(4 * xsize1 * y) + (4 * x + 3)]);
178
179					color1 = float16_to_float(color1i);
180					color1 = clamp(0, 65504.0f, color1);
181				}
182				else // if (img1->data_type == ASTCENC_TYPE_F32)
183				{
184					assert(img1->data_type == ASTCENC_TYPE_F32);
185					float* data32 = static_cast<float*>(img1->data[z]);
186
187					color1 = vfloat4(
188					    data32[(4 * xsize1 * y) + (4 * x    )],
189					    data32[(4 * xsize1 * y) + (4 * x + 1)],
190					    data32[(4 * xsize1 * y) + (4 * x + 2)],
191					    data32[(4 * xsize1 * y) + (4 * x + 3)]);
192
193					color1 = clamp(0, 65504.0f, color1);
194				}
195
196				if (img2->data_type == ASTCENC_TYPE_U8)
197				{
198					uint8_t* data8 = static_cast<uint8_t*>(img2->data[z]);
199
200					color2 = vfloat4(
201					    data8[(4 * xsize2 * y) + (4 * x    )],
202					    data8[(4 * xsize2 * y) + (4 * x + 1)],
203					    data8[(4 * xsize2 * y) + (4 * x + 2)],
204					    data8[(4 * xsize2 * y) + (4 * x + 3)]);
205
206					color2 = color2 / 255.0f;
207				}
208				else if (img2->data_type == ASTCENC_TYPE_F16)
209				{
210					uint16_t* data16 = static_cast<uint16_t*>(img2->data[z]);
211
212					vint4 color2i = vint4(
213					    data16[(4 * xsize2 * y) + (4 * x    )],
214					    data16[(4 * xsize2 * y) + (4 * x + 1)],
215					    data16[(4 * xsize2 * y) + (4 * x + 2)],
216					    data16[(4 * xsize2 * y) + (4 * x + 3)]);
217
218					color2 = float16_to_float(color2i);
219					color2 = clamp(0, 65504.0f, color2);
220				}
221				else // if (img2->data_type == ASTCENC_TYPE_F32)
222				{
223					assert(img2->data_type == ASTCENC_TYPE_F32);
224					float* data32 = static_cast<float*>(img2->data[z]);
225
226					color2 = vfloat4(
227					    data32[(4 * xsize2 * y) + (4 * x    )],
228					    data32[(4 * xsize2 * y) + (4 * x + 1)],
229					    data32[(4 * xsize2 * y) + (4 * x + 2)],
230					    data32[(4 * xsize2 * y) + (4 * x + 3)]);
231
232					color2 = clamp(0, 65504.0f, color2);
233				}
234
235				rgb_peak = astc::max(static_cast<double>(color1.lane<0>()),
236				                     static_cast<double>(color1.lane<1>()),
237				                     static_cast<double>(color1.lane<2>()),
238				                     rgb_peak);
239
240				vfloat4 diffcolor = color1 - color2;
241				vfloat4 diffcolor_sq = diffcolor * diffcolor;
242				errorsum += diffcolor_sq;
243
244				vfloat4 alpha_scaled_diffcolor = vfloat4(
245				    diffcolor.lane<0>() * color1.lane<3>(),
246				    diffcolor.lane<1>() * color1.lane<3>(),
247				    diffcolor.lane<2>() * color1.lane<3>(),
248				    diffcolor.lane<3>());
249
250				vfloat4 alpha_scaled_diffcolor_sq = alpha_scaled_diffcolor * alpha_scaled_diffcolor;
251				alpha_scaled_errorsum += alpha_scaled_diffcolor_sq;
252
253				if (compute_hdr_metrics)
254				{
255					vfloat4 log_input_color1 = log2(color1);
256					vfloat4 log_input_color2 = log2(color2);
257
258					vfloat4 log_diffcolor = log_input_color1 - log_input_color2;
259
260					log_errorsum += log_diffcolor * log_diffcolor;
261
262					vfloat4 mpsnr_error = vfloat4(
263					    mpsnr_sumdiff(color1.lane<0>(), color2.lane<0>(), fstop_lo, fstop_hi),
264					    mpsnr_sumdiff(color1.lane<1>(), color2.lane<1>(), fstop_lo, fstop_hi),
265					    mpsnr_sumdiff(color1.lane<2>(), color2.lane<2>(), fstop_lo, fstop_hi),
266					    mpsnr_sumdiff(color1.lane<3>(), color2.lane<3>(), fstop_lo, fstop_hi));
267
268					mpsnr_errorsum += mpsnr_error;
269				}
270
271				if (compute_normal_metrics)
272				{
273					// Decode the normal vector
274					vfloat4 normal1 = (color1 - 0.5f) * 2.0f;
275					normal1 = normalize_safe(normal1.swz<0, 1, 2>(), unit3());
276
277					vfloat4 normal2 = (color2 - 0.5f) * 2.0f;
278					normal2 = normalize_safe(normal2.swz<0, 1, 2>(), unit3());
279
280					// Float error can push this outside of valid range for acos, so clamp to avoid NaN issues
281					float normal_cos = clamp(-1.0f, 1.0f, dot3(normal1, normal2)).lane<0>();
282					float rad_to_degrees = 180.0f / astc::PI;
283					double error_degrees = std::acos(static_cast<double>(normal_cos)) * static_cast<double>(rad_to_degrees);
284
285					mean_angular_errorsum += error_degrees / (dim_x * dim_y * dim_z);
286					worst_angular_errorsum = astc::max(worst_angular_errorsum, error_degrees);
287				}
288			}
289		}
290	}
291
292	double pixels = static_cast<double>(dim_x * dim_y * dim_z);
293	double samples = 0.0;
294
295	double num = 0.0;
296	double alpha_num = 0.0;
297	double log_num = 0.0;
298	double mpsnr_num = 0.0;
299
300	if (componentmask & 1)
301	{
302		num += errorsum.sum_r;
303		alpha_num += alpha_scaled_errorsum.sum_r;
304		log_num += log_errorsum.sum_r;
305		mpsnr_num += mpsnr_errorsum.sum_r;
306		samples += pixels;
307	}
308
309	if (componentmask & 2)
310	{
311		num += errorsum.sum_g;
312		alpha_num += alpha_scaled_errorsum.sum_g;
313		log_num += log_errorsum.sum_g;
314		mpsnr_num += mpsnr_errorsum.sum_g;
315		samples += pixels;
316	}
317
318	if (componentmask & 4)
319	{
320		num += errorsum.sum_b;
321		alpha_num += alpha_scaled_errorsum.sum_b;
322		log_num += log_errorsum.sum_b;
323		mpsnr_num += mpsnr_errorsum.sum_b;
324		samples += pixels;
325	}
326
327	if (componentmask & 8)
328	{
329		num += errorsum.sum_a;
330		alpha_num += alpha_scaled_errorsum.sum_a;
331		samples += pixels;
332	}
333
334	double denom = samples;
335	double stopcount = static_cast<double>(fstop_hi - fstop_lo + 1);
336	double mpsnr_denom = pixels * 3.0 * stopcount * 255.0 * 255.0;
337
338	double psnr;
339	if (num == 0.0)
340	{
341		psnr = 999.0;
342	}
343	else
344	{
345		psnr = 10.0 * log10(denom / num);
346	}
347
348	double rgb_psnr = psnr;
349
350	printf("Quality metrics\n");
351	printf("===============\n\n");
352
353	if (componentmask & 8)
354	{
355		printf("    PSNR (LDR-RGBA):          %9.4f dB\n", psnr);
356
357		double alpha_psnr;
358		if (alpha_num == 0.0)
359		{
360			alpha_psnr = 999.0;
361		}
362		else
363		{
364			alpha_psnr = 10.0 * log10(denom / alpha_num);
365		}
366		printf("    Alpha-weighted PSNR:      %9.4f dB\n", alpha_psnr);
367
368		double rgb_num = errorsum.sum_r + errorsum.sum_g + errorsum.sum_b;
369		if (rgb_num == 0.0)
370		{
371			rgb_psnr = 999.0;
372		}
373		else
374		{
375			rgb_psnr = 10.0 * log10(pixels * 3.0 / rgb_num);
376		}
377		printf("    PSNR (LDR-RGB):           %9.4f dB\n", rgb_psnr);
378	}
379	else
380	{
381		printf("    PSNR (LDR-RGB):           %9.4f dB\n", psnr);
382	}
383
384	if (compute_hdr_metrics)
385	{
386		printf("    PSNR (RGB norm to peak):  %9.4f dB (peak %f)\n",
387		       rgb_psnr + 20.0 * log10(rgb_peak), rgb_peak);
388
389		double mpsnr;
390		if (mpsnr_num == 0.0)
391		{
392			mpsnr = 999.0;
393		}
394		else
395		{
396			mpsnr = 10.0 * log10(mpsnr_denom / mpsnr_num);
397		}
398
399		printf("    mPSNR (RGB):              %9.4f dB (fstops %+d to %+d)\n",
400		       mpsnr, fstop_lo, fstop_hi);
401
402		double logrmse = sqrt(log_num / pixels);
403		printf("    LogRMSE (RGB):            %9.4f\n", logrmse);
404	}
405
406	if (compute_normal_metrics)
407	{
408		printf("    Mean Angular Error:       %9.4f degrees\n", mean_angular_errorsum);
409		printf("    Worst Angular Error:      %9.4f degrees\n", worst_angular_errorsum);
410	}
411
412	printf("\n");
413}
414