1// SPDX-License-Identifier: Apache-2.0
2// ----------------------------------------------------------------------------
3// Copyright 2020-2024 Arm Limited
4//
5// Licensed under the Apache License, Version 2.0 (the "License"); you may not
6// use this file except in compliance with the License. You may obtain a copy
7// of the License at:
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14// License for the specific language governing permissions and limitations
15// under the License.
16// ----------------------------------------------------------------------------
17
18/**
19 * @brief Generic 4x32-bit vector functions.
20 *
21 * This module implements generic 4-wide vector functions that are valid for
22 * all instruction sets, typically implemented using lower level 4-wide
23 * operations that are ISA-specific.
24 */
25
26#ifndef ASTC_VECMATHLIB_COMMON_4_H_INCLUDED
27#define ASTC_VECMATHLIB_COMMON_4_H_INCLUDED
28
29#ifndef ASTCENC_SIMD_INLINE
30	#error "Include astcenc_vecmathlib.h, do not include directly"
31#endif
32
33#include <cstdio>
34
35// ============================================================================
36// vmask4 operators and functions
37// ============================================================================
38
39/**
40 * @brief True if any lanes are enabled, false otherwise.
41 */
42ASTCENC_SIMD_INLINE bool any(vmask4 a)
43{
44	return mask(a) != 0;
45}
46
47/**
48 * @brief True if all lanes are enabled, false otherwise.
49 */
50ASTCENC_SIMD_INLINE bool all(vmask4 a)
51{
52	return mask(a) == 0xF;
53}
54
55// ============================================================================
56// vint4 operators and functions
57// ============================================================================
58
59/**
60 * @brief Overload: vector by scalar addition.
61 */
62ASTCENC_SIMD_INLINE vint4 operator+(vint4 a, int b)
63{
64	return a + vint4(b);
65}
66
67/**
68 * @brief Overload: vector by vector incremental addition.
69 */
70ASTCENC_SIMD_INLINE vint4& operator+=(vint4& a, const vint4& b)
71{
72	a = a + b;
73	return a;
74}
75
76/**
77 * @brief Overload: vector by scalar subtraction.
78 */
79ASTCENC_SIMD_INLINE vint4 operator-(vint4 a, int b)
80{
81	return a - vint4(b);
82}
83
84/**
85 * @brief Overload: vector by scalar multiplication.
86 */
87ASTCENC_SIMD_INLINE vint4 operator*(vint4 a, int b)
88{
89	return a * vint4(b);
90}
91
92/**
93 * @brief Overload: vector by scalar bitwise or.
94 */
95ASTCENC_SIMD_INLINE vint4 operator|(vint4 a, int b)
96{
97	return a | vint4(b);
98}
99
100/**
101 * @brief Overload: vector by scalar bitwise and.
102 */
103ASTCENC_SIMD_INLINE vint4 operator&(vint4 a, int b)
104{
105	return a & vint4(b);
106}
107
108/**
109 * @brief Overload: vector by scalar bitwise xor.
110 */
111ASTCENC_SIMD_INLINE vint4 operator^(vint4 a, int b)
112{
113	return a ^ vint4(b);
114}
115
116/**
117 * @brief Return the clamped value between min and max.
118 */
119ASTCENC_SIMD_INLINE vint4 clamp(int minv, int maxv, vint4 a)
120{
121	return min(max(a, vint4(minv)), vint4(maxv));
122}
123
124/**
125 * @brief Return the horizontal sum of RGB vector lanes as a scalar.
126 */
127ASTCENC_SIMD_INLINE int hadd_rgb_s(vint4 a)
128{
129	return a.lane<0>() + a.lane<1>() + a.lane<2>();
130}
131
132// ============================================================================
133// vfloat4 operators and functions
134// ============================================================================
135
136/**
137 * @brief Overload: vector by vector incremental addition.
138 */
139ASTCENC_SIMD_INLINE vfloat4& operator+=(vfloat4& a, const vfloat4& b)
140{
141	a = a + b;
142	return a;
143}
144
145/**
146 * @brief Overload: vector by scalar addition.
147 */
148ASTCENC_SIMD_INLINE vfloat4 operator+(vfloat4 a, float b)
149{
150	return a + vfloat4(b);
151}
152
153/**
154 * @brief Overload: vector by scalar subtraction.
155 */
156ASTCENC_SIMD_INLINE vfloat4 operator-(vfloat4 a, float b)
157{
158	return a - vfloat4(b);
159}
160
161/**
162 * @brief Overload: vector by scalar multiplication.
163 */
164ASTCENC_SIMD_INLINE vfloat4 operator*(vfloat4 a, float b)
165{
166	return a * vfloat4(b);
167}
168
169/**
170 * @brief Overload: scalar by vector multiplication.
171 */
172ASTCENC_SIMD_INLINE vfloat4 operator*(float a, vfloat4 b)
173{
174	return vfloat4(a) * b;
175}
176
177/**
178 * @brief Overload: vector by scalar division.
179 */
180ASTCENC_SIMD_INLINE vfloat4 operator/(vfloat4 a, float b)
181{
182	return a / vfloat4(b);
183}
184
185/**
186 * @brief Overload: scalar by vector division.
187 */
188ASTCENC_SIMD_INLINE vfloat4 operator/(float a, vfloat4 b)
189{
190	return vfloat4(a) / b;
191}
192
193/**
194 * @brief Return the min vector of a vector and a scalar.
195 *
196 * If either lane value is NaN, @c b will be returned for that lane.
197 */
198ASTCENC_SIMD_INLINE vfloat4 min(vfloat4 a, float b)
199{
200	return min(a, vfloat4(b));
201}
202
203/**
204 * @brief Return the max vector of a vector and a scalar.
205 *
206 * If either lane value is NaN, @c b will be returned for that lane.
207 */
208ASTCENC_SIMD_INLINE vfloat4 max(vfloat4 a, float b)
209{
210	return max(a, vfloat4(b));
211}
212
213/**
214 * @brief Return the clamped value between min and max.
215 *
216 * It is assumed that neither @c min nor @c max are NaN values. If @c a is NaN
217 * then @c min will be returned for that lane.
218 */
219ASTCENC_SIMD_INLINE vfloat4 clamp(float minv, float maxv, vfloat4 a)
220{
221	// Do not reorder - second operand will return if either is NaN
222	return min(max(a, minv), maxv);
223}
224
225/**
226 * @brief Return the clamped value between 0.0f and max.
227 *
228 * It is assumed that  @c max is not a NaN value. If @c a is NaN then zero will
229 * be returned for that lane.
230 */
231ASTCENC_SIMD_INLINE vfloat4 clampz(float maxv, vfloat4 a)
232{
233	// Do not reorder - second operand will return if either is NaN
234	return min(max(a, vfloat4::zero()), maxv);
235}
236
237/**
238 * @brief Return the clamped value between 0.0f and 1.0f.
239 *
240 * If @c a is NaN then zero will be returned for that lane.
241 */
242ASTCENC_SIMD_INLINE vfloat4 clampzo(vfloat4 a)
243{
244	// Do not reorder - second operand will return if either is NaN
245	return min(max(a, vfloat4::zero()), 1.0f);
246}
247
248/**
249 * @brief Return the horizontal minimum of a vector.
250 */
251ASTCENC_SIMD_INLINE float hmin_s(vfloat4 a)
252{
253	return hmin(a).lane<0>();
254}
255
256/**
257 * @brief Return the horizontal min of RGB vector lanes as a scalar.
258 */
259ASTCENC_SIMD_INLINE float hmin_rgb_s(vfloat4 a)
260{
261	a.set_lane<3>(a.lane<0>());
262	return hmin_s(a);
263}
264
265/**
266 * @brief Return the horizontal maximum of a vector.
267 */
268ASTCENC_SIMD_INLINE float hmax_s(vfloat4 a)
269{
270	return hmax(a).lane<0>();
271}
272
273/**
274 * @brief Accumulate lane-wise sums for a vector.
275 */
276ASTCENC_SIMD_INLINE void haccumulate(vfloat4& accum, vfloat4 a)
277{
278	accum = accum + a;
279}
280
281/**
282 * @brief Accumulate lane-wise sums for a masked vector.
283 */
284ASTCENC_SIMD_INLINE void haccumulate(vfloat4& accum, vfloat4 a, vmask4 m)
285{
286	a = select(vfloat4::zero(), a, m);
287	haccumulate(accum, a);
288}
289
290#define ASTCENC_USE_COMMON_GATHERF
291ASTCENC_SIMD_INLINE vfloat4 gatherf(const float* base, const uint8_t* idx)
292{
293	return vfloat4(base[idx[0]], base[idx[1]], base[idx[2]], base[idx[3]]);    // index 0,1,2,3
294}
295
296/**
297 * @brief Return the horizontal sum of RGB vector lanes as a scalar.
298 */
299ASTCENC_SIMD_INLINE float hadd_rgb_s(vfloat4 a)
300{
301	return a.lane<0>() + a.lane<1>() + a.lane<2>();
302}
303
304#if !defined(ASTCENC_USE_NATIVE_ADDV)
305/**
306 * @brief Return the horizontal sum of a vector.
307 */
308ASTCENC_SIMD_INLINE float hadd_rgba_s(vfloat4 a)
309{
310	return a.lane<0>() + a.lane<1>() + a.lane<2>() + a.lane<3>();    // channel 0,1,2,3
311}
312#endif
313
314#if !defined(ASTCENC_USE_NATIVE_DOT_PRODUCT)
315
316/**
317 * @brief Return the dot product for the full 4 lanes, returning scalar.
318 */
319ASTCENC_SIMD_INLINE float dot_s(vfloat4 a, vfloat4 b)
320{
321	vfloat4 m = a * b;
322	return hadd_s(m);
323}
324
325/**
326 * @brief Return the dot product for the full 4 lanes, returning vector.
327 */
328ASTCENC_SIMD_INLINE vfloat4 dot(vfloat4 a, vfloat4 b)
329{
330	vfloat4 m = a * b;
331	return vfloat4(hadd_s(m));
332}
333
334/**
335 * @brief Return the dot product for the bottom 3 lanes, returning scalar.
336 */
337ASTCENC_SIMD_INLINE float dot3_s(vfloat4 a, vfloat4 b)
338{
339	vfloat4 m = a * b;
340	return hadd_rgb_s(m);
341}
342
343/**
344 * @brief Return the dot product for the bottom 3 lanes, returning vector.
345 */
346ASTCENC_SIMD_INLINE vfloat4 dot3(vfloat4 a, vfloat4 b)
347{
348	vfloat4 m = a * b;
349	float d3 = hadd_rgb_s(m);
350	return vfloat4(d3, d3, d3, 0.0f);
351}
352
353#endif
354
355#if !defined(ASTCENC_USE_NATIVE_POPCOUNT)
356
/**
 * @brief Population bit count.
 *
 * Uses a branch-free SWAR reduction: fold bit pairs, then nibbles, then
 * bytes, and finally gather the per-byte counts into the top byte with a
 * multiply.
 *
 * @param v   The value to population count.
 *
 * @return The number of 1 bits.
 */
static inline int popcount(uint64_t v)
{
	v = v - ((v >> 1) & 0x5555555555555555ULL);
	v = (v & 0x3333333333333333ULL) + ((v >> 2) & 0x3333333333333333ULL);
	v = (v + (v >> 4)) & 0x0F0F0F0F0F0F0F0FULL;
	return static_cast<int>((v * 0x0101010101010101ULL) >> 56);
}
377
378#endif
379
380/**
381 * @brief Apply signed bit transfer.
382 *
383 * @param input0   The first encoded endpoint.
384 * @param input1   The second encoded endpoint.
385 */
386static ASTCENC_SIMD_INLINE void bit_transfer_signed(
387	vint4& input0,
388	vint4& input1
389) {
390	input1 = lsr<1>(input1) | (input0 & 0x80);
391	input0 = lsr<1>(input0) & 0x3F;
392
393	vmask4 mask = (input0 & 0x20) != vint4::zero();
394	input0 = select(input0, input0 - 0x40, mask);
395}
396
397/**
398 * @brief Debug function to print a vector of ints.
399 */
400ASTCENC_SIMD_INLINE void print(vint4 a)
401{
402	ASTCENC_ALIGNAS int v[4];
403	storea(a, v);
404	printf("v4_i32:\n  %8d %8d %8d %8d\n",
405	       v[0], v[1], v[2], v[3]);
406}
407
408/**
409 * @brief Debug function to print a vector of ints.
410 */
411ASTCENC_SIMD_INLINE void printx(vint4 a)
412{
413	ASTCENC_ALIGNAS int v[4];
414	storea(a, v);
415	printf("v4_i32:\n  %08x %08x %08x %08x\n",
416	       v[0], v[1], v[2], v[3]);
417}
418
419/**
420 * @brief Debug function to print a vector of floats.
421 */
422ASTCENC_SIMD_INLINE void print(vfloat4 a)
423{
424	ASTCENC_ALIGNAS float v[4];
425	storea(a, v);
426	printf("v4_f32:\n  %0.4f %0.4f %0.4f %0.4f\n",
427	       static_cast<double>(v[0]), static_cast<double>(v[1]),
428	       static_cast<double>(v[2]), static_cast<double>(v[3]));
429}
430
431/**
432 * @brief Debug function to print a vector of masks.
433 */
434ASTCENC_SIMD_INLINE void print(vmask4 a)
435{
436	print(select(vint4(0), vint4(1), a));
437}
438
439#endif // #ifndef ASTC_VECMATHLIB_COMMON_4_H_INCLUDED
440