1/*
2 * Copyright © 2018 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
21 * DEALINGS IN THE SOFTWARE.
22 */
23#include <gtest/gtest.h>
24#include "nir.h"
25#include "nir_builder.h"
26#include "util/half_float.h"
27
28static void count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
29                           nir_alu_type full_type, int first);
30static void negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS],
31                   const nir_const_value src[NIR_MAX_VEC_COMPONENTS],
32                   nir_alu_type full_type, unsigned components);
33
34class const_value_negative_equal_test : public ::testing::Test {
35protected:
36   const_value_negative_equal_test()
37   {
38      glsl_type_singleton_init_or_ref();
39
40      memset(c1, 0, sizeof(c1));
41      memset(c2, 0, sizeof(c2));
42   }
43
44   ~const_value_negative_equal_test()
45   {
46      glsl_type_singleton_decref();
47   }
48
49   nir_const_value c1[NIR_MAX_VEC_COMPONENTS];
50   nir_const_value c2[NIR_MAX_VEC_COMPONENTS];
51};
52
53class alu_srcs_negative_equal_test : public ::testing::Test {
54protected:
55   alu_srcs_negative_equal_test()
56   {
57      glsl_type_singleton_init_or_ref();
58
59      static const nir_shader_compiler_options options = { };
60      bld = nir_builder_init_simple_shader(MESA_SHADER_VERTEX, &options,
61                                           "negative equal tests");
62      memset(c1, 0, sizeof(c1));
63      memset(c2, 0, sizeof(c2));
64   }
65
66   ~alu_srcs_negative_equal_test()
67   {
68      ralloc_free(bld.shader);
69      glsl_type_singleton_decref();
70   }
71
72   struct nir_builder bld;
73   nir_const_value c1[NIR_MAX_VEC_COMPONENTS];
74   nir_const_value c2[NIR_MAX_VEC_COMPONENTS];
75};
76
77TEST_F(const_value_negative_equal_test, float32_zero)
78{
79   /* Verify that 0.0 negative-equals 0.0. */
80   EXPECT_TRUE(nir_const_value_negative_equal(c1[0], c1[0], nir_type_float32));
81}
82
83TEST_F(const_value_negative_equal_test, float64_zero)
84{
85   /* Verify that 0.0 negative-equals 0.0. */
86   EXPECT_TRUE(nir_const_value_negative_equal(c1[0], c1[0], nir_type_float64));
87}
88
89/* Compare an object with non-zero values to itself.  This should always be
90 * false.
91 */
92#define compare_with_self(full_type)                                    \
93TEST_F(const_value_negative_equal_test, full_type ## _self)             \
94{                                                                       \
95   count_sequence(c1, full_type, 1);                                    \
96   EXPECT_FALSE(nir_const_value_negative_equal(c1[0], c1[0], full_type)); \
97}
98
99compare_with_self(nir_type_float16)
100compare_with_self(nir_type_float32)
101compare_with_self(nir_type_float64)
102compare_with_self(nir_type_int8)
103compare_with_self(nir_type_uint8)
104compare_with_self(nir_type_int16)
105compare_with_self(nir_type_uint16)
106compare_with_self(nir_type_int32)
107compare_with_self(nir_type_uint32)
108compare_with_self(nir_type_int64)
109compare_with_self(nir_type_uint64)
110#undef compare_with_self
111
112/* Compare an object with the negation of itself.  This should always be true.
113 */
114#define compare_with_negation(full_type)                                \
115TEST_F(const_value_negative_equal_test, full_type ## _trivially_true)   \
116{                                                                       \
117   count_sequence(c1, full_type, 1);                                    \
118   negate(c2, c1, full_type, 1);                                        \
119   EXPECT_TRUE(nir_const_value_negative_equal(c1[0], c2[0], full_type)); \
120}
121
122compare_with_negation(nir_type_float16)
123compare_with_negation(nir_type_float32)
124compare_with_negation(nir_type_float64)
125compare_with_negation(nir_type_int8)
126compare_with_negation(nir_type_uint8)
127compare_with_negation(nir_type_int16)
128compare_with_negation(nir_type_uint16)
129compare_with_negation(nir_type_int32)
130compare_with_negation(nir_type_uint32)
131compare_with_negation(nir_type_int64)
132compare_with_negation(nir_type_uint64)
133#undef compare_with_negation
134
135TEST_F(alu_srcs_negative_equal_test, trivial_float)
136{
137   nir_ssa_def *two = nir_imm_float(&bld, 2.0f);
138   nir_ssa_def *negative_two = nir_imm_float(&bld, -2.0f);
139
140   nir_ssa_def *result = nir_fadd(&bld, two, negative_two);
141   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
142
143   ASSERT_NE((void *) 0, instr);
144   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
145   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0));
146   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
147}
148
149TEST_F(alu_srcs_negative_equal_test, trivial_int)
150{
151   nir_ssa_def *two = nir_imm_int(&bld, 2);
152   nir_ssa_def *negative_two = nir_imm_int(&bld, -2);
153
154   nir_ssa_def *result = nir_iadd(&bld, two, negative_two);
155   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
156
157   ASSERT_NE((void *) 0, instr);
158   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
159   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0));
160   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
161}
162
163TEST_F(alu_srcs_negative_equal_test, trivial_negation_float)
164{
165   /* Cannot just do the negation of a nir_load_const_instr because
166    * nir_alu_srcs_negative_equal expects that constant folding will convert
167    * fneg(2.0) to just -2.0.
168    */
169   nir_ssa_def *two = nir_imm_float(&bld, 2.0f);
170   nir_ssa_def *two_plus_two = nir_fadd(&bld, two, two);
171   nir_ssa_def *negation = nir_fneg(&bld, two_plus_two);
172
173   nir_ssa_def *result = nir_fadd(&bld, two_plus_two, negation);
174
175   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
176
177   ASSERT_NE((void *) 0, instr);
178   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
179   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0));
180   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
181}
182
183TEST_F(alu_srcs_negative_equal_test, trivial_negation_int)
184{
185   /* Cannot just do the negation of a nir_load_const_instr because
186    * nir_alu_srcs_negative_equal expects that constant folding will convert
187    * ineg(2) to just -2.
188    */
189   nir_ssa_def *two = nir_imm_int(&bld, 2);
190   nir_ssa_def *two_plus_two = nir_iadd(&bld, two, two);
191   nir_ssa_def *negation = nir_ineg(&bld, two_plus_two);
192
193   nir_ssa_def *result = nir_iadd(&bld, two_plus_two, negation);
194
195   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
196
197   ASSERT_NE((void *) 0, instr);
198   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
199   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0));
200   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));
201}
202
203/* Compare an object with non-zero values to itself.  This should always be
204 * false.
205 */
206#define compare_with_self(full_type)                                    \
207TEST_F(alu_srcs_negative_equal_test, full_type ## _self)                \
208{                                                                       \
209   count_sequence(c1, full_type, 1);                                    \
210   nir_ssa_def *a = nir_build_imm(&bld,                                 \
211                                  NIR_MAX_VEC_COMPONENTS,               \
212                                  nir_alu_type_get_type_size(full_type), \
213                                  c1);                                  \
214   nir_ssa_def *result;                                                 \
215   if (nir_alu_type_get_base_type(full_type) == nir_type_float)         \
216      result = nir_fadd(&bld, a, a);                                    \
217   else                                                                 \
218      result = nir_iadd(&bld, a, a);                                    \
219   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);       \
220   ASSERT_NE((void *) 0, instr);                                        \
221   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0));       \
222   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));       \
223   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 0));       \
224   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));       \
225}
226
227compare_with_self(nir_type_float16)
228compare_with_self(nir_type_float32)
229compare_with_self(nir_type_float64)
230compare_with_self(nir_type_int8)
231compare_with_self(nir_type_uint8)
232compare_with_self(nir_type_int16)
233compare_with_self(nir_type_uint16)
234compare_with_self(nir_type_int32)
235compare_with_self(nir_type_uint32)
236compare_with_self(nir_type_int64)
237compare_with_self(nir_type_uint64)
238
239/* Compare an object with the negation of itself.  This should always be true.
240 */
241#define compare_with_negation(full_type)                                \
242TEST_F(alu_srcs_negative_equal_test, full_type ## _trivially_true)      \
243{                                                                       \
244   count_sequence(c1, full_type, 1);                                    \
245   negate(c2, c1, full_type, NIR_MAX_VEC_COMPONENTS);                   \
246   nir_ssa_def *a = nir_build_imm(&bld,                                 \
247                                  NIR_MAX_VEC_COMPONENTS,               \
248                                  nir_alu_type_get_type_size(full_type), \
249                                  c1);                                  \
250   nir_ssa_def *b = nir_build_imm(&bld,                                 \
251                                  NIR_MAX_VEC_COMPONENTS,               \
252                                  nir_alu_type_get_type_size(full_type), \
253                                  c2);                                  \
254   nir_ssa_def *result;                                                 \
255   if (nir_alu_type_get_base_type(full_type) == nir_type_float)         \
256      result = nir_fadd(&bld, a, b);                                    \
257   else                                                                 \
258      result = nir_iadd(&bld, a, b);                                    \
259   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);       \
260   ASSERT_NE((void *) 0, instr);                                        \
261   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 0, 0));       \
262   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));        \
263   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 1, 0));        \
264   EXPECT_FALSE(nir_alu_srcs_negative_equal(instr, instr, 1, 1));       \
265}
266
267compare_with_negation(nir_type_float16)
268compare_with_negation(nir_type_float32)
269compare_with_negation(nir_type_float64)
270compare_with_negation(nir_type_int8)
271compare_with_negation(nir_type_uint8)
272compare_with_negation(nir_type_int16)
273compare_with_negation(nir_type_uint16)
274compare_with_negation(nir_type_int32)
275compare_with_negation(nir_type_uint32)
276compare_with_negation(nir_type_int64)
277compare_with_negation(nir_type_uint64)
278
279TEST_F(alu_srcs_negative_equal_test, swizzle_scalar_to_vector)
280{
281   nir_ssa_def *v = nir_imm_vec2(&bld, 1.0, -1.0);
282   const uint8_t s0[4] = { 0, 0, 0, 0 };
283   const uint8_t s1[4] = { 1, 1, 1, 1 };
284
285   /* We can't use nir_swizzle here because it inserts an extra MOV. */
286   nir_alu_instr *instr = nir_alu_instr_create(bld.shader, nir_op_fadd);
287
288   instr->src[0].src = nir_src_for_ssa(v);
289   instr->src[1].src = nir_src_for_ssa(v);
290
291   memcpy(&instr->src[0].swizzle, s0, sizeof(s0));
292   memcpy(&instr->src[1].swizzle, s1, sizeof(s1));
293
294   nir_builder_alu_instr_finish_and_insert(&bld, instr);
295
296   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
297}
298
299TEST_F(alu_srcs_negative_equal_test, unused_components_mismatch)
300{
301   nir_ssa_def *v1 = nir_imm_vec4(&bld, -2.0, 18.0, 43.0,  1.0);
302   nir_ssa_def *v2 = nir_imm_vec4(&bld,  2.0, 99.0, 76.0, -1.0);
303
304   nir_ssa_def *result = nir_fadd(&bld, v1, v2);
305
306   nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr);
307
308   /* Disable the channels that aren't negations of each other. */
309   nir_register *reg = nir_local_reg_create(bld.impl);
310   nir_instr_rewrite_dest(&instr->instr, &instr->dest.dest, nir_dest_for_reg(reg));
311   instr->dest.write_mask = 8 + 1;
312
313   EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1));
314}
315
316static void
317count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS],
318               nir_alu_type full_type, int first)
319{
320   switch (full_type) {
321   case nir_type_float16:
322      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
323         c[i].u16 = _mesa_float_to_half(float(i + first));
324
325      break;
326
327   case nir_type_float32:
328      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
329         c[i].f32 = float(i + first);
330
331      break;
332
333   case nir_type_float64:
334      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
335         c[i].f64 = double(i + first);
336
337      break;
338
339   case nir_type_int8:
340   case nir_type_uint8:
341      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
342         c[i].i8 = i + first;
343
344      break;
345
346   case nir_type_int16:
347   case nir_type_uint16:
348      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
349         c[i].i16 = i + first;
350
351      break;
352
353   case nir_type_int32:
354   case nir_type_uint32:
355      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
356         c[i].i32 = i + first;
357
358      break;
359
360   case nir_type_int64:
361   case nir_type_uint64:
362      for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
363         c[i].i64 = i + first;
364
365      break;
366
367   case nir_type_bool:
368   default:
369      unreachable("invalid base type");
370   }
371}
372
373static void
374negate(nir_const_value dst[NIR_MAX_VEC_COMPONENTS],
375       const nir_const_value src[NIR_MAX_VEC_COMPONENTS],
376       nir_alu_type full_type, unsigned components)
377{
378   switch (full_type) {
379   case nir_type_float16:
380      for (unsigned i = 0; i < components; i++)
381         dst[i].u16 = _mesa_float_to_half(-_mesa_half_to_float(src[i].u16));
382
383      break;
384
385   case nir_type_float32:
386      for (unsigned i = 0; i < components; i++)
387         dst[i].f32 = -src[i].f32;
388
389      break;
390
391   case nir_type_float64:
392      for (unsigned i = 0; i < components; i++)
393         dst[i].f64 = -src[i].f64;
394
395      break;
396
397   case nir_type_int8:
398   case nir_type_uint8:
399      for (unsigned i = 0; i < components; i++)
400         dst[i].i8 = -src[i].i8;
401
402      break;
403
404   case nir_type_int16:
405   case nir_type_uint16:
406      for (unsigned i = 0; i < components; i++)
407         dst[i].i16 = -src[i].i16;
408
409      break;
410
411   case nir_type_int32:
412   case nir_type_uint32:
413      for (unsigned i = 0; i < components; i++)
414         dst[i].i32 = -src[i].i32;
415
416      break;
417
418   case nir_type_int64:
419   case nir_type_uint64:
420      for (unsigned i = 0; i < components; i++)
421         dst[i].i64 = -src[i].i64;
422
423      break;
424
425   case nir_type_bool:
426   default:
427      unreachable("invalid base type");
428   }
429}
430