1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/ast/call_statement.h"
16 #include "src/resolver/resolver_test_helper.h"
17 
18 namespace tint {
19 namespace resolver {
20 namespace {
21 
22 template <typename T>
23 using DataType = builder::DataType<T>;
24 template <typename T>
25 using vec2 = builder::vec2<T>;
26 template <typename T>
27 using vec3 = builder::vec3<T>;
28 template <typename T>
29 using vec4 = builder::vec4<T>;
30 using f32 = builder::f32;
31 using i32 = builder::i32;
32 using u32 = builder::u32;
33 
34 class ResolverBuiltinsValidationTest : public resolver::TestHelper,
35                                        public testing::Test {};
36 namespace StageTest {
37 struct Params {
38   builder::ast_type_func_ptr type;
39   ast::Builtin builtin;
40   ast::PipelineStage stage;
41   bool is_valid;
42 };
43 
44 template <typename T>
ParamsFor(ast::Builtin builtin, ast::PipelineStage stage, bool is_valid)45 constexpr Params ParamsFor(ast::Builtin builtin,
46                            ast::PipelineStage stage,
47                            bool is_valid) {
48   return Params{DataType<T>::AST, builtin, stage, is_valid};
49 }
50 static constexpr Params cases[] = {
51     ParamsFor<vec4<f32>>(ast::Builtin::kPosition,
52                          ast::PipelineStage::kVertex,
53                          false),
54     ParamsFor<vec4<f32>>(ast::Builtin::kPosition,
55                          ast::PipelineStage::kFragment,
56                          true),
57     ParamsFor<vec4<f32>>(ast::Builtin::kPosition,
58                          ast::PipelineStage::kCompute,
59                          false),
60 
61     ParamsFor<u32>(ast::Builtin::kVertexIndex,
62                    ast::PipelineStage::kVertex,
63                    true),
64     ParamsFor<u32>(ast::Builtin::kVertexIndex,
65                    ast::PipelineStage::kFragment,
66                    false),
67     ParamsFor<u32>(ast::Builtin::kVertexIndex,
68                    ast::PipelineStage::kCompute,
69                    false),
70 
71     ParamsFor<u32>(ast::Builtin::kInstanceIndex,
72                    ast::PipelineStage::kVertex,
73                    true),
74     ParamsFor<u32>(ast::Builtin::kInstanceIndex,
75                    ast::PipelineStage::kFragment,
76                    false),
77     ParamsFor<u32>(ast::Builtin::kInstanceIndex,
78                    ast::PipelineStage::kCompute,
79                    false),
80 
81     ParamsFor<bool>(ast::Builtin::kFrontFacing,
82                     ast::PipelineStage::kVertex,
83                     false),
84     ParamsFor<bool>(ast::Builtin::kFrontFacing,
85                     ast::PipelineStage::kFragment,
86                     true),
87     ParamsFor<bool>(ast::Builtin::kFrontFacing,
88                     ast::PipelineStage::kCompute,
89                     false),
90 
91     ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
92                          ast::PipelineStage::kVertex,
93                          false),
94     ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
95                          ast::PipelineStage::kFragment,
96                          false),
97     ParamsFor<vec3<u32>>(ast::Builtin::kLocalInvocationId,
98                          ast::PipelineStage::kCompute,
99                          true),
100 
101     ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
102                    ast::PipelineStage::kVertex,
103                    false),
104     ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
105                    ast::PipelineStage::kFragment,
106                    false),
107     ParamsFor<u32>(ast::Builtin::kLocalInvocationIndex,
108                    ast::PipelineStage::kCompute,
109                    true),
110 
111     ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
112                          ast::PipelineStage::kVertex,
113                          false),
114     ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
115                          ast::PipelineStage::kFragment,
116                          false),
117     ParamsFor<vec3<u32>>(ast::Builtin::kGlobalInvocationId,
118                          ast::PipelineStage::kCompute,
119                          true),
120 
121     ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
122                          ast::PipelineStage::kVertex,
123                          false),
124     ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
125                          ast::PipelineStage::kFragment,
126                          false),
127     ParamsFor<vec3<u32>>(ast::Builtin::kWorkgroupId,
128                          ast::PipelineStage::kCompute,
129                          true),
130 
131     ParamsFor<vec3<u32>>(ast::Builtin::kNumWorkgroups,
132                          ast::PipelineStage::kVertex,
133                          false),
134     ParamsFor<vec3<u32>>(ast::Builtin::kNumWorkgroups,
135                          ast::PipelineStage::kFragment,
136                          false),
137     ParamsFor<vec3<u32>>(ast::Builtin::kNumWorkgroups,
138                          ast::PipelineStage::kCompute,
139                          true),
140 
141     ParamsFor<u32>(ast::Builtin::kSampleIndex,
142                    ast::PipelineStage::kVertex,
143                    false),
144     ParamsFor<u32>(ast::Builtin::kSampleIndex,
145                    ast::PipelineStage::kFragment,
146                    true),
147     ParamsFor<u32>(ast::Builtin::kSampleIndex,
148                    ast::PipelineStage::kCompute,
149                    false),
150 
151     ParamsFor<u32>(ast::Builtin::kSampleMask,
152                    ast::PipelineStage::kVertex,
153                    false),
154     ParamsFor<u32>(ast::Builtin::kSampleMask,
155                    ast::PipelineStage::kFragment,
156                    true),
157     ParamsFor<u32>(ast::Builtin::kSampleMask,
158                    ast::PipelineStage::kCompute,
159                    false),
160 };
161 
162 using ResolverBuiltinsStageTest = ResolverTestWithParam<Params>;
TEST_P(ResolverBuiltinsStageTest, All_input)163 TEST_P(ResolverBuiltinsStageTest, All_input) {
164   const Params& params = GetParam();
165 
166   auto* p = Global("p", ty.vec4<f32>(), ast::StorageClass::kPrivate);
167   auto* input =
168       Param("input", params.type(*this),
169             ast::DecorationList{Builtin(Source{{12, 34}}, params.builtin)});
170   switch (params.stage) {
171     case ast::PipelineStage::kVertex:
172       Func("main", {input}, ty.vec4<f32>(), {Return(p)},
173            {Stage(ast::PipelineStage::kVertex)},
174            {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)});
175       break;
176     case ast::PipelineStage::kFragment:
177       Func("main", {input}, ty.void_(), {},
178            {Stage(ast::PipelineStage::kFragment)}, {});
179       break;
180     case ast::PipelineStage::kCompute:
181       Func("main", {input}, ty.void_(), {},
182            ast::DecorationList{Stage(ast::PipelineStage::kCompute),
183                                WorkgroupSize(1)});
184       break;
185     default:
186       break;
187   }
188 
189   if (params.is_valid) {
190     EXPECT_TRUE(r()->Resolve()) << r()->error();
191   } else {
192     std::stringstream err;
193     err << "12:34 error: builtin(" << params.builtin << ")";
194     err << " cannot be used in input of " << params.stage << " pipeline stage";
195     EXPECT_FALSE(r()->Resolve());
196     EXPECT_EQ(r()->error(), err.str());
197   }
198 }
199 INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
200                          ResolverBuiltinsStageTest,
201                          testing::ValuesIn(cases));
202 
TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInput_Fail) {
  // WGSL under test:
  //   [[stage(fragment)]]
  //   fn fs_main([[builtin(frag_depth)]] fd: f32) -> [[location(0)]] f32 {
  //     return 1.0;
  //   }
  // frag_depth may not decorate a fragment-stage input parameter.
  ast::DecorationList depth_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)};
  auto* depth_param = Param("fd", ty.f32(), depth_attr);
  Func("fs_main", ast::VariableList{depth_param}, ty.f32(), {Return(1.0f)},
       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
       {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: builtin(frag_depth) cannot be used in input of "
            "fragment pipeline stage");
}
219 
TEST_F(ResolverBuiltinsValidationTest, FragDepthIsInputStruct_Fail) {
  // WGSL under test:
  //   struct MyInputs {
  //     [[builtin(frag_depth)]] frag_depth: f32;
  //   };
  //   [[stage(fragment)]]
  //   fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
  // The input restriction also applies to builtins nested in an input struct;
  // the diagnostic then carries a note naming the entry point.
  auto* depth_member =
      Member("frag_depth", ty.f32(),
             ast::DecorationList{
                 Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
  auto* inputs = Structure("MyInputs", {depth_member});

  Func("fragShader", {Param("arg", ty.Of(inputs))}, ty.f32(), {Return(1.0f)},
       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: builtin(frag_depth) cannot be used in input of "
            "fragment pipeline stage\n"
            "note: while analysing entry point 'fragShader'");
}
240 
TEST_F(ResolverBuiltinsValidationTest, StructBuiltinInsideEntryPoint_Ignored) {
  // struct S {
  //   [[builtin(vertex_index)]] idx: u32;
  // };
  // [[stage(fragment)]]
  // fn fragShader() { var s : S; }
  //
  // vertex_index is rejected as a fragment input (see the StageTest cases),
  // but here S is only the type of a function-local variable, not of an
  // entry-point parameter, so its builtin decoration must not be checked
  // against the stage.
  Structure("S",
            {Member("idx", ty.u32(), {Builtin(ast::Builtin::kVertexIndex)})});

  Func("fragShader", {}, ty.void_(), {Decl(Var("s", ty.type_name("S")))},
       {Stage(ast::PipelineStage::kFragment)});
  // Stream the diagnostic so an unexpected failure is actionable.
  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
255 
256 }  // namespace StageTest
257 
TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) {
  // WGSL under test:
  //   struct MyInputs {
  //     [[builtin(position)]] position: vec4<u32>;
  //   };
  //   [[stage(fragment)]]
  //   fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
  ast::DecorationList position_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kPosition)};
  auto* position_member = Member("position", ty.vec4<u32>(), position_attr);
  auto* inputs = Structure("MyInputs", {position_member});
  Func("fragShader", {Param("arg", ty.Of(inputs))}, ty.f32(), {Return(1.0f)},
       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(position) must be 'vec4<f32>'");
}
276 
TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_ReturnType_Fail) {
  // WGSL under test:
  //   [[stage(vertex)]]
  //   fn main() -> [[builtin(position)]] f32 { return 1.0; }
  // A scalar f32 is not an acceptable store type for a position output.
  ast::DecorationList position_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kPosition)};
  Func("main", {}, ty.f32(), {Return(1.0f)},
       {Stage(ast::PipelineStage::kVertex)}, position_attr);

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(position) must be 'vec4<f32>'");
}
288 
TEST_F(ResolverBuiltinsValidationTest, FragDepthNotF32_Struct_Fail) {
  // WGSL under test:
  //   struct MyInputs {
  //     [[builtin(frag_depth)]] frag_depth: i32;
  //   };
  //   [[stage(fragment)]]
  //   fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
  // NOTE(review): frag_depth is also not a valid fragment input (see
  // FragDepthIsInputStruct_Fail); this test relies on the store-type error
  // being reported first. Consider using an output struct instead.
  ast::DecorationList depth_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)};
  auto* depth_member = Member("frag_depth", ty.i32(), depth_attr);
  auto* inputs = Structure("MyInputs", {depth_member});
  Func("fragShader", {Param("arg", ty.Of(inputs))}, ty.f32(), {Return(1.0f)},
       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(frag_depth) must be 'f32'");
}
307 
TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_Struct_Fail) {
  // WGSL under test:
  //   struct MyInputs {
  //     [[builtin(sample_mask)]] m: f32;
  //   };
  //   [[stage(fragment)]]
  //   fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
  ast::DecorationList mask_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kSampleMask)};
  auto* mask_member = Member("m", ty.f32(), mask_attr);
  auto* inputs = Structure("MyInputs", {mask_member});
  Func("fragShader", {Param("arg", ty.Of(inputs))}, ty.f32(), {Return(1.0f)},
       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(sample_mask) must be 'u32'");
}
326 
TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_ReturnType_Fail) {
  // WGSL under test:
  //   [[stage(fragment)]]
  //   fn main() -> [[builtin(sample_mask)]] i32 { return 1; }
  // A sample_mask output must be u32, not i32.
  ast::DecorationList mask_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kSampleMask)};
  Func("main", {}, ty.i32(), {Return(1)},
       {Stage(ast::PipelineStage::kFragment)}, mask_attr);

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(sample_mask) must be 'u32'");
}
338 
TEST_F(ResolverBuiltinsValidationTest, SampleMaskIsNotU32_Fail) {
  // WGSL under test:
  //   [[stage(fragment)]]
  //   fn fs_main([[builtin(sample_mask)]] arg: bool) -> [[location(0)]] f32 {
  //     return 1.0;
  //   }
  ast::DecorationList mask_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kSampleMask)};
  auto* mask_param = Param("arg", ty.bool_(), mask_attr);
  Func("fs_main", ast::VariableList{mask_param}, ty.f32(), {Return(1.0f)},
       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
       {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(sample_mask) must be 'u32'");
}
354 
TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Struct_Fail) {
  // WGSL under test:
  //   struct MyInputs {
  //     [[builtin(sample_index)]] m: f32;
  //   };
  //   [[stage(fragment)]]
  //   fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
  ast::DecorationList index_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kSampleIndex)};
  auto* index_member = Member("m", ty.f32(), index_attr);
  auto* inputs = Structure("MyInputs", {index_member});
  Func("fragShader", {Param("arg", ty.Of(inputs))}, ty.f32(), {Return(1.0f)},
       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(sample_index) must be 'u32'");
}
373 
TEST_F(ResolverBuiltinsValidationTest, SampleIndexIsNotU32_Fail) {
  // WGSL under test:
  //   [[stage(fragment)]]
  //   fn fs_main([[builtin(sample_index)]] arg: bool) -> [[location(0)]] f32 {
  //     return 1.0;
  //   }
  ast::DecorationList index_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kSampleIndex)};
  auto* index_param = Param("arg", ty.bool_(), index_attr);
  Func("fs_main", ast::VariableList{index_param}, ty.f32(), {Return(1.0f)},
       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
       {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(sample_index) must be 'u32'");
}
389 
TEST_F(ResolverBuiltinsValidationTest, PositionIsNotF32_Fail) {
  // WGSL under test:
  //   [[stage(fragment)]]
  //   fn fs_main([[builtin(position)]] p: vec3<f32>) -> [[location(0)]] f32 {
  //     return 1.0;
  //   }
  // A three-component vector is rejected: position must be vec4<f32>.
  ast::DecorationList position_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kPosition)};
  auto* position_param = Param("p", ty.vec3<f32>(), position_attr);
  Func("fs_main", ast::VariableList{position_param}, ty.f32(), {Return(1.0f)},
       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
       {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(position) must be 'vec4<f32>'");
}
405 
TEST_F(ResolverBuiltinsValidationTest, FragDepthIsNotF32_Fail) {
  // [[stage(fragment)]]
  // fn fs_main() -> [[builtin(frag_depth)]] i32 { var fd: i32; return fd; }
  //
  // The entry point returns i32, but a frag_depth output must be f32.
  auto* fd = Var("fd", ty.i32());
  Func(
      "fs_main", {}, ty.i32(), {Decl(fd), Return(fd)},
      ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
      ast::DecorationList{Builtin(Source{{12, 34}}, ast::Builtin::kFragDepth)});
  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(frag_depth) must be 'f32'");
}
418 
TEST_F(ResolverBuiltinsValidationTest, VertexIndexIsNotU32_Fail) {
  // var<private> p : vec4<f32>;
  // [[stage(vertex)]]
  // fn main(
  //   [[builtin(vertex_index)]] vi : f32,
  // ) -> [[builtin(position)]] vec4<f32> { return p; }
  //
  // The vertex_index parameter's store type is the only error in this
  // program. builtin(position) is not a valid vertex-stage *input* (see the
  // StageTest cases), so the returned position comes from a private global
  // rather than from a second entry-point parameter.
  auto* p = Global("p", ty.vec4<f32>(), ast::StorageClass::kPrivate);
  auto* vi = Param("vi", ty.f32(),
                   ast::DecorationList{
                       Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)});
  Func("main", ast::VariableList{vi}, ty.vec4<f32>(), {Return(p)},
       ast::DecorationList{Stage(ast::PipelineStage::kVertex)},
       ast::DecorationList{Builtin(ast::Builtin::kPosition)});
  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(vertex_index) must be 'u32'");
}
437 
TEST_F(ResolverBuiltinsValidationTest, InstanceIndexIsNotU32) {
  // var<private> p : vec4<f32>;
  // [[stage(vertex)]]
  // fn main(
  //   [[builtin(instance_index)]] ii : f32,
  // ) -> [[builtin(position)]] vec4<f32> { return p; }
  //
  // The instance_index parameter's store type is the only error in this
  // program. builtin(position) is not a valid vertex-stage *input* (see the
  // StageTest cases), so the returned position comes from a private global
  // rather than from a second entry-point parameter.
  auto* p = Global("p", ty.vec4<f32>(), ast::StorageClass::kPrivate);
  auto* ii = Param("ii", ty.f32(),
                   ast::DecorationList{Builtin(Source{{12, 34}},
                                               ast::Builtin::kInstanceIndex)});
  Func("main", ast::VariableList{ii}, ty.vec4<f32>(), {Return(p)},
       ast::DecorationList{Stage(ast::PipelineStage::kVertex)},
       ast::DecorationList{Builtin(ast::Builtin::kPosition)});
  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(instance_index) must be 'u32'");
}
456 
TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltin_Pass) {
  // WGSL under test:
  //   [[stage(fragment)]]
  //   fn fs_main(
  //     [[builtin(position)]] p: vec4<f32>,
  //     [[builtin(front_facing)]] ff: bool,
  //     [[builtin(sample_index)]] si: u32,
  //     [[builtin(sample_mask)]] sm: u32
  //   ) -> [[builtin(frag_depth)]] f32 { var fd: f32; return fd; }
  // All fragment input builtins with their required types, plus a frag_depth
  // output, must resolve without error.
  auto* pos_param =
      Param("p", ty.vec4<f32>(),
            ast::DecorationList{Builtin(ast::Builtin::kPosition)});
  auto* facing_param =
      Param("ff", ty.bool_(),
            ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)});
  auto* index_param =
      Param("si", ty.u32(),
            ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)});
  auto* mask_param =
      Param("sm", ty.u32(),
            ast::DecorationList{Builtin(ast::Builtin::kSampleMask)});
  auto* depth_var = Var("fd", ty.f32());

  ast::VariableList inputs{pos_param, facing_param, index_param, mask_param};
  Func("fs_main", inputs, ty.f32(), {Decl(depth_var), Return(depth_var)},
       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
       ast::DecorationList{Builtin(ast::Builtin::kFragDepth)});

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
480 
TEST_F(ResolverBuiltinsValidationTest, VertexBuiltin_Pass) {
  // WGSL under test:
  //   [[stage(vertex)]]
  //   fn main(
  //     [[builtin(vertex_index)]] vi : u32,
  //     [[builtin(instance_index)]] ii : u32,
  //   ) -> [[builtin(position)]] vec4<f32> { var p :vec4<f32>; return p; }
  // Both vertex input builtins with their required u32 type must resolve.
  auto* vertex_idx =
      Param("vi", ty.u32(),
            ast::DecorationList{
                Builtin(Source{{12, 34}}, ast::Builtin::kVertexIndex)});
  auto* instance_idx =
      Param("ii", ty.u32(),
            ast::DecorationList{
                Builtin(Source{{12, 34}}, ast::Builtin::kInstanceIndex)});
  auto* position_var = Var("p", ty.vec4<f32>());

  ast::StatementList body{Decl(position_var), Return(position_var)};
  Func("main", ast::VariableList{vertex_idx, instance_idx}, ty.vec4<f32>(),
       body, ast::DecorationList{Stage(ast::PipelineStage::kVertex)},
       ast::DecorationList{Builtin(ast::Builtin::kPosition)});

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
505 
TEST_F(ResolverBuiltinsValidationTest, ComputeBuiltin_Pass) {
  // [[stage(compute), workgroup_size(2)]]
  // fn main(
  //   [[builtin(local_invocation_id)]] li_id: vec3<u32>,
  //   [[builtin(local_invocation_index)]] li_index: u32,
  //   [[builtin(global_invocation_id)]] gi: vec3<u32>,
  //   [[builtin(workgroup_id)]] wi: vec3<u32>,
  //   [[builtin(num_workgroups)]] nwgs: vec3<u32>,
  // ) {}
  //
  // Every compute input builtin with its required type must resolve.

  auto* li_id =
      Param("li_id", ty.vec3<u32>(),
            ast::DecorationList{Builtin(ast::Builtin::kLocalInvocationId)});
  auto* li_index =
      Param("li_index", ty.u32(),
            ast::DecorationList{Builtin(ast::Builtin::kLocalInvocationIndex)});
  auto* gi =
      Param("gi", ty.vec3<u32>(),
            ast::DecorationList{Builtin(ast::Builtin::kGlobalInvocationId)});
  auto* wi = Param("wi", ty.vec3<u32>(),
                   ast::DecorationList{Builtin(ast::Builtin::kWorkgroupId)});
  auto* nwgs =
      Param("nwgs", ty.vec3<u32>(),
            ast::DecorationList{Builtin(ast::Builtin::kNumWorkgroups)});

  Func("main", ast::VariableList{li_id, li_index, gi, wi, nwgs}, ty.void_(), {},
       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                           WorkgroupSize(2)});

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
538 
TEST_F(ResolverBuiltinsValidationTest, ComputeBuiltin_WorkGroupIdNotVec3U32) {
  // [[stage(compute), workgroup_size(2)]]
  // fn main([[builtin(workgroup_id)]] wi: f32) {}
  //
  // Only the builtin under test carries the 12:34 source location; the
  // workgroup_size argument has none, so the location in the expected
  // diagnostic can only come from the workgroup_id decoration.
  auto* wi = Param("wi", ty.f32(),
                   ast::DecorationList{
                       Builtin(Source{{12, 34}}, ast::Builtin::kWorkgroupId)});
  Func("main", ast::VariableList{wi}, ty.void_(), {},
       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                           WorkgroupSize(2)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(workgroup_id) must be "
            "'vec3<u32>'");
}
553 
TEST_F(ResolverBuiltinsValidationTest, ComputeBuiltin_NumWorkgroupsNotVec3U32) {
  // [[stage(compute), workgroup_size(2)]]
  // fn main([[builtin(num_workgroups)]] nwgs: f32) {}
  //
  // Only the builtin under test carries the 12:34 source location; the
  // workgroup_size argument has none, so the location in the expected
  // diagnostic can only come from the num_workgroups decoration.
  auto* nwgs = Param("nwgs", ty.f32(),
                     ast::DecorationList{Builtin(
                         Source{{12, 34}}, ast::Builtin::kNumWorkgroups)});
  Func("main", ast::VariableList{nwgs}, ty.void_(), {},
       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                           WorkgroupSize(2)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(num_workgroups) must be "
            "'vec3<u32>'");
}
568 
TEST_F(ResolverBuiltinsValidationTest,
       ComputeBuiltin_GlobalInvocationNotVec3U32) {
  // [[stage(compute), workgroup_size(2)]]
  // fn main([[builtin(global_invocation_id)]] gi: vec3<i32>) {}
  //
  // Only the builtin under test carries the 12:34 source location; the
  // workgroup_size argument has none, so the location in the expected
  // diagnostic can only come from the global_invocation_id decoration.
  auto* gi = Param("gi", ty.vec3<i32>(),
                   ast::DecorationList{Builtin(
                       Source{{12, 34}}, ast::Builtin::kGlobalInvocationId)});
  Func("main", ast::VariableList{gi}, ty.void_(), {},
       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                           WorkgroupSize(2)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(global_invocation_id) must be "
            "'vec3<u32>'");
}
584 
TEST_F(ResolverBuiltinsValidationTest,
       ComputeBuiltin_LocalInvocationIndexNotU32) {
  // [[stage(compute), workgroup_size(2)]]
  // fn main([[builtin(local_invocation_index)]] li_index: vec3<u32>) {}
  //
  // Only the builtin under test carries the 12:34 source location; the
  // workgroup_size argument has none, so the location in the expected
  // diagnostic can only come from the local_invocation_index decoration.
  auto* li_index =
      Param("li_index", ty.vec3<u32>(),
            ast::DecorationList{Builtin(Source{{12, 34}},
                                        ast::Builtin::kLocalInvocationIndex)});
  Func("main", ast::VariableList{li_index}, ty.void_(), {},
       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                           WorkgroupSize(2)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(
      r()->error(),
      "12:34 error: store type of builtin(local_invocation_index) must be "
      "'u32'");
}
602 
TEST_F(ResolverBuiltinsValidationTest,
       ComputeBuiltin_LocalInvocationNotVec3U32) {
  // [[stage(compute), workgroup_size(2)]]
  // fn main([[builtin(local_invocation_id)]] li_id: vec2<u32>) {}
  //
  // Only the builtin under test carries the 12:34 source location; the
  // workgroup_size argument has none, so the location in the expected
  // diagnostic can only come from the local_invocation_id decoration.
  auto* li_id = Param("li_id", ty.vec2<u32>(),
                      ast::DecorationList{Builtin(
                          Source{{12, 34}}, ast::Builtin::kLocalInvocationId)});
  Func("main", ast::VariableList{li_id}, ty.void_(), {},
       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                           WorkgroupSize(2)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(local_invocation_id) must be "
            "'vec3<u32>'");
}
618 
TEST_F(ResolverBuiltinsValidationTest, FragmentBuiltinStruct_Pass) {
  // struct MyInputs {
  //   [[builtin(position)]] position: vec4<f32>;
  //   [[builtin(front_facing)]] front_facing: bool;
  //   [[builtin(sample_index)]] sample_index: u32;
  //   [[builtin(sample_mask)]] sample_mask: u32;
  // };
  // [[stage(fragment)]]
  // fn fragShader(arg: MyInputs) -> [[location(0)]] f32 { return 1.0; }
  //
  // All fragment input builtins with their required types, delivered through
  // an input struct, must resolve without error.

  auto* s = Structure(
      "MyInputs",
      {Member("position", ty.vec4<f32>(),
              ast::DecorationList{Builtin(ast::Builtin::kPosition)}),
       Member("front_facing", ty.bool_(),
              ast::DecorationList{Builtin(ast::Builtin::kFrontFacing)}),
       Member("sample_index", ty.u32(),
              ast::DecorationList{Builtin(ast::Builtin::kSampleIndex)}),
       Member("sample_mask", ty.u32(),
              ast::DecorationList{Builtin(ast::Builtin::kSampleMask)})});
  Func("fragShader", {Param("arg", ty.Of(s))}, ty.f32(), {Return(1.0f)},
       {Stage(ast::PipelineStage::kFragment)}, {Location(0)});
  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
643 
TEST_F(ResolverBuiltinsValidationTest, FrontFacingParamIsNotBool_Fail) {
  // WGSL under test:
  //   [[stage(fragment)]]
  //   fn fs_main([[builtin(front_facing)]] is_front: i32)
  //       -> [[location(0)]] f32 { return 1.0; }
  // front_facing must be declared as bool.
  ast::DecorationList facing_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kFrontFacing)};
  auto* facing_param = Param("is_front", ty.i32(), facing_attr);

  Func("fs_main", ast::VariableList{facing_param}, ty.f32(), {Return(1.0f)},
       ast::DecorationList{Stage(ast::PipelineStage::kFragment)},
       {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(front_facing) must be 'bool'");
}
661 
TEST_F(ResolverBuiltinsValidationTest, FrontFacingMemberIsNotBool_Fail) {
  // WGSL under test:
  //   struct MyInputs {
  //     [[builtin(front_facing)]] pos: f32;
  //   };
  //   [[stage(fragment)]]
  //   fn fragShader(is_front: MyInputs) -> [[location(0)]] f32 { return 1.0; }
  ast::DecorationList facing_attr{
      Builtin(Source{{12, 34}}, ast::Builtin::kFrontFacing)};
  auto* facing_member = Member("pos", ty.f32(), facing_attr);
  auto* inputs = Structure("MyInputs", {facing_member});

  Func("fragShader", {Param("is_front", ty.Of(inputs))}, ty.f32(),
       {Return(1.0f)}, {Stage(ast::PipelineStage::kFragment)}, {Location(0)});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: store type of builtin(front_facing) must be 'bool'");
}
680 
TEST_F(ResolverBuiltinsValidationTest, Length_Float_Scalar) {
  // length() accepts a scalar f32 argument.
  WrapInFunction(Call("length", 1.0f));

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
687 
TEST_F(ResolverBuiltinsValidationTest, Length_Float_Vec2) {
  // length() accepts a vec2<f32> argument.
  auto* operand = vec2<f32>(1.0f, 1.0f);
  WrapInFunction(Call("length", operand));

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
694 
TEST_F(ResolverBuiltinsValidationTest, Length_Float_Vec3) {
  // length() accepts a vec3<f32> argument.
  auto* operand = vec3<f32>(1.0f, 1.0f, 1.0f);
  WrapInFunction(Call("length", operand));

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
701 
TEST_F(ResolverBuiltinsValidationTest, Length_Float_Vec4) {
  // length() accepts a vec4<f32> argument.
  auto* operand = vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f);
  WrapInFunction(Call("length", operand));

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
708 
// distance(f32, f32) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Scalar) {
  auto* builtin = Call("distance", 1.0f, 1.0f);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
715 
// distance(vec2<f32>, vec2<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Vec2) {
  auto* builtin =
      Call("distance", vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
723 
// distance(vec3<f32>, vec3<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Vec3) {
  auto* builtin = Call("distance", vec3<f32>(1.0f, 1.0f, 1.0f),
                       vec3<f32>(1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
731 
// distance(vec4<f32>, vec4<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Distance_Float_Vec4) {
  auto* builtin = Call("distance", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
                       vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
739 
// determinant(mat2x2<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Determinant_Mat2x2) {
  auto* builtin = Call(
      "determinant", mat2x2<f32>(vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f)));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
747 
// determinant(mat3x3<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Determinant_Mat3x3) {
  auto* builtin = Call("determinant", mat3x3<f32>(vec3<f32>(1.0f, 1.0f, 1.0f),
                                                  vec3<f32>(1.0f, 1.0f, 1.0f),
                                                  vec3<f32>(1.0f, 1.0f, 1.0f)));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
756 
// determinant(mat4x4<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Determinant_Mat4x4) {
  auto* builtin =
      Call("determinant", mat4x4<f32>(vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
                                      vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
                                      vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
                                      vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f)));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
767 
// frexp(f32) returns a two-member struct whose members are f32 and i32.
TEST_F(ResolverBuiltinsValidationTest, Frexp_Scalar) {
  auto* builtin = Call("frexp", 1.0f);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  EXPECT_TRUE(members[0]->Type()->Is<sem::F32>());
  EXPECT_TRUE(members[1]->Type()->Is<sem::I32>());
}
780 
// frexp(vec2<f32>) returns a two-member struct of vec2<f32> and vec2<i32>.
TEST_F(ResolverBuiltinsValidationTest, Frexp_Vec2) {
  auto* builtin = Call("frexp", vec2<f32>(1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  auto* float_part = members[0]->Type()->As<sem::Vector>();
  auto* int_part = members[1]->Type()->As<sem::Vector>();
  ASSERT_NE(float_part, nullptr);
  ASSERT_NE(int_part, nullptr);
  EXPECT_EQ(float_part->Width(), 2u);
  EXPECT_TRUE(float_part->type()->Is<sem::F32>());
  EXPECT_EQ(int_part->Width(), 2u);
  EXPECT_TRUE(int_part->type()->Is<sem::I32>());
}
797 
// frexp(vec3<f32>) returns a two-member struct of vec3<f32> and vec3<i32>.
TEST_F(ResolverBuiltinsValidationTest, Frexp_Vec3) {
  auto* builtin = Call("frexp", vec3<f32>(1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  auto* float_part = members[0]->Type()->As<sem::Vector>();
  auto* int_part = members[1]->Type()->As<sem::Vector>();
  ASSERT_NE(float_part, nullptr);
  ASSERT_NE(int_part, nullptr);
  EXPECT_EQ(float_part->Width(), 3u);
  EXPECT_TRUE(float_part->type()->Is<sem::F32>());
  EXPECT_EQ(int_part->Width(), 3u);
  EXPECT_TRUE(int_part->type()->Is<sem::I32>());
}
814 
// frexp(vec4<f32>) returns a two-member struct of vec4<f32> and vec4<i32>.
TEST_F(ResolverBuiltinsValidationTest, Frexp_Vec4) {
  auto* builtin = Call("frexp", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  auto* float_part = members[0]->Type()->As<sem::Vector>();
  auto* int_part = members[1]->Type()->As<sem::Vector>();
  ASSERT_NE(float_part, nullptr);
  ASSERT_NE(int_part, nullptr);
  EXPECT_EQ(float_part->Width(), 4u);
  EXPECT_TRUE(float_part->type()->Is<sem::F32>());
  EXPECT_EQ(int_part->Width(), 4u);
  EXPECT_TRUE(int_part->type()->Is<sem::I32>());
}
831 
// modf(f32) returns a two-member struct whose members are both f32.
TEST_F(ResolverBuiltinsValidationTest, Modf_Scalar) {
  auto* builtin = Call("modf", 1.0f);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  EXPECT_TRUE(members[0]->Type()->Is<sem::F32>());
  EXPECT_TRUE(members[1]->Type()->Is<sem::F32>());
}
844 
// modf(vec2<f32>) returns a two-member struct whose members are both vec2<f32>.
TEST_F(ResolverBuiltinsValidationTest, Modf_Vec2) {
  auto* builtin = Call("modf", vec2<f32>(1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  auto* first = members[0]->Type()->As<sem::Vector>();
  auto* second = members[1]->Type()->As<sem::Vector>();
  ASSERT_NE(first, nullptr);
  ASSERT_NE(second, nullptr);
  EXPECT_EQ(first->Width(), 2u);
  EXPECT_TRUE(first->type()->Is<sem::F32>());
  EXPECT_EQ(second->Width(), 2u);
  EXPECT_TRUE(second->type()->Is<sem::F32>());
}
861 
// modf(vec3<f32>) returns a two-member struct whose members are both vec3<f32>.
TEST_F(ResolverBuiltinsValidationTest, Modf_Vec3) {
  auto* builtin = Call("modf", vec3<f32>(1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  auto* first = members[0]->Type()->As<sem::Vector>();
  auto* second = members[1]->Type()->As<sem::Vector>();
  ASSERT_NE(first, nullptr);
  ASSERT_NE(second, nullptr);
  EXPECT_EQ(first->Width(), 3u);
  EXPECT_TRUE(first->type()->Is<sem::F32>());
  EXPECT_EQ(second->Width(), 3u);
  EXPECT_TRUE(second->type()->Is<sem::F32>());
}
878 
// modf(vec4<f32>) returns a two-member struct whose members are both vec4<f32>.
TEST_F(ResolverBuiltinsValidationTest, Modf_Vec4) {
  auto* builtin = Call("modf", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* res_ty = TypeOf(builtin)->As<sem::Struct>();
  ASSERT_TRUE(res_ty != nullptr);
  auto& members = res_ty->Members();
  ASSERT_EQ(members.size(), 2u);
  auto* first = members[0]->Type()->As<sem::Vector>();
  auto* second = members[1]->Type()->As<sem::Vector>();
  ASSERT_NE(first, nullptr);
  ASSERT_NE(second, nullptr);
  EXPECT_EQ(first->Width(), 4u);
  EXPECT_TRUE(first->type()->Is<sem::F32>());
  EXPECT_EQ(second->Width(), 4u);
  EXPECT_TRUE(second->type()->Is<sem::F32>());
}
895 
// cross(vec3<f32>, vec3<f32>) resolves and yields a vec3<f32>.
TEST_F(ResolverBuiltinsValidationTest, Cross_Float_Vec3) {
  auto* builtin =
      Call("cross", vec3<f32>(1.0f, 1.0f, 1.0f), vec3<f32>(1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 3u);
  EXPECT_TRUE(result_ty->type()->Is<sem::F32>());
}
903 
// dot(vec2<f32>, vec2<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Dot_Float_Vec2) {
  auto* builtin = Call("dot", vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
910 
// dot(vec3<f32>, vec3<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Dot_Float_Vec3) {
  auto* builtin =
      Call("dot", vec3<f32>(1.0f, 1.0f, 1.0f), vec3<f32>(1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
918 
// dot(vec4<f32>, vec4<f32>) resolves and yields an f32 scalar.
TEST_F(ResolverBuiltinsValidationTest, Dot_Float_Vec4) {
  auto* builtin = Call("dot", vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f),
                       vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
926 
// select(f32, f32, bool) resolves and yields the operand type, f32.
TEST_F(ResolverBuiltinsValidationTest, Select_Float_Scalar) {
  auto* builtin = Call("select", Expr(1.0f), Expr(1.0f), Expr(true));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
933 
// select(i32, i32, bool) resolves and yields the operand type, i32.
TEST_F(ResolverBuiltinsValidationTest, Select_Integer_Scalar) {
  auto* builtin = Call("select", Expr(1), Expr(1), Expr(true));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::I32>());
}
940 
// select(bool, bool, bool) resolves and yields the operand type, bool.
TEST_F(ResolverBuiltinsValidationTest, Select_Boolean_Scalar) {
  auto* builtin = Call("select", Expr(true), Expr(true), Expr(true));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::Bool>());
}
947 
// select(vec2<f32>, vec2<f32>, vec2<bool>) resolves and yields vec2<f32>.
TEST_F(ResolverBuiltinsValidationTest, Select_Float_Vec2) {
  auto* builtin = Call("select", vec2<f32>(1.0f, 1.0f), vec2<f32>(1.0f, 1.0f),
                       vec2<bool>(true, true));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 2u);
  EXPECT_TRUE(result_ty->type()->Is<sem::F32>());
}
955 
// select(vec2<i32>, vec2<i32>, vec2<bool>) resolves and yields vec2<i32>.
TEST_F(ResolverBuiltinsValidationTest, Select_Integer_Vec2) {
  auto* builtin =
      Call("select", vec2<int>(1, 1), vec2<int>(1, 1), vec2<bool>(true, true));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 2u);
  EXPECT_TRUE(result_ty->type()->Is<sem::I32>());
}
963 
// select(vec2<bool>, vec2<bool>, vec2<bool>) resolves and yields vec2<bool>.
TEST_F(ResolverBuiltinsValidationTest, Select_Boolean_Vec2) {
  auto* builtin = Call("select", vec2<bool>(true, true), vec2<bool>(true, true),
                       vec2<bool>(true, true));
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 2u);
  EXPECT_TRUE(result_ty->type()->Is<sem::Bool>());
}
971 
972 template <typename T>
973 class ResolverBuiltinsValidationTestWithParams
974     : public resolver::TestHelper,
975       public testing::TestWithParam<T> {};
976 
// Fixture for float builtins whose arguments all share one type.
// Param: (builtin name, number of arguments to pass).
using FloatAllMatching =
    ResolverBuiltinsValidationTestWithParams<std::tuple<std::string, uint32_t>>;
979 
// Calls the parameterized builtin with `num_params` f32 scalars from a
// fragment entry point. Some of the builtins (e.g. dpdx) are stage-restricted.
// Checks that the call yields an f32 scalar.
TEST_P(FloatAllMatching, Scalar) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(Expr(1.0f));
  }
  auto* builtin = Call(name, params);
  Func("func", {}, ty.void_(), {CallStmt(builtin)},
       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::F32>());
}
995 
// Calls the parameterized builtin with `num_params` vec2<f32> arguments from a
// fragment entry point, and checks that the call yields a vec2<f32>.
TEST_P(FloatAllMatching, Vec2) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec2<f32>(1.0f, 1.0f));
  }
  auto* builtin = Call(name, params);
  Func("func", {}, ty.void_(), {CallStmt(builtin)},
       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  // Check the width too; is_float_vector() alone would accept any vecN<f32>.
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 2u);
  EXPECT_TRUE(result_ty->type()->Is<sem::F32>());
}
1011 
// Calls the parameterized builtin with `num_params` vec3<f32> arguments from a
// fragment entry point, and checks that the call yields a vec3<f32>.
TEST_P(FloatAllMatching, Vec3) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec3<f32>(1.0f, 1.0f, 1.0f));
  }
  auto* builtin = Call(name, params);
  Func("func", {}, ty.void_(), {CallStmt(builtin)},
       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  // Check the width too; is_float_vector() alone would accept any vecN<f32>.
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 3u);
  EXPECT_TRUE(result_ty->type()->Is<sem::F32>());
}
1027 
// Calls the parameterized builtin with `num_params` vec4<f32> arguments from a
// fragment entry point, and checks that the call yields a vec4<f32>.
TEST_P(FloatAllMatching, Vec4) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
  }
  auto* builtin = Call(name, params);
  Func("func", {}, ty.void_(), {CallStmt(builtin)},
       {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  // Check the width too; is_float_vector() alone would accept any vecN<f32>.
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 4u);
  EXPECT_TRUE(result_ty->type()->Is<sem::F32>());
}
1043 
// Each entry is (builtin name, argument count). The list order determines the
// generated test-name suffixes, so append new entries rather than reordering.
INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
                         FloatAllMatching,
                         ::testing::Values(std::make_tuple("abs", 1),
                                           std::make_tuple("acos", 1),
                                           std::make_tuple("asin", 1),
                                           std::make_tuple("atan", 1),
                                           std::make_tuple("atan2", 2),
                                           std::make_tuple("ceil", 1),
                                           std::make_tuple("clamp", 3),
                                           std::make_tuple("cos", 1),
                                           std::make_tuple("cosh", 1),
                                           std::make_tuple("dpdx", 1),
                                           std::make_tuple("dpdxCoarse", 1),
                                           std::make_tuple("dpdxFine", 1),
                                           std::make_tuple("dpdy", 1),
                                           std::make_tuple("dpdyCoarse", 1),
                                           std::make_tuple("dpdyFine", 1),
                                           std::make_tuple("exp", 1),
                                           std::make_tuple("exp2", 1),
                                           std::make_tuple("floor", 1),
                                           std::make_tuple("fma", 3),
                                           std::make_tuple("fract", 1),
                                           std::make_tuple("fwidth", 1),
                                           std::make_tuple("fwidthCoarse", 1),
                                           std::make_tuple("fwidthFine", 1),
                                           std::make_tuple("inverseSqrt", 1),
                                           std::make_tuple("log", 1),
                                           std::make_tuple("log2", 1),
                                           std::make_tuple("max", 2),
                                           std::make_tuple("min", 2),
                                           std::make_tuple("mix", 3),
                                           std::make_tuple("pow", 2),
                                           std::make_tuple("round", 1),
                                           std::make_tuple("sign", 1),
                                           std::make_tuple("sin", 1),
                                           std::make_tuple("sinh", 1),
                                           std::make_tuple("smoothStep", 3),
                                           std::make_tuple("sqrt", 1),
                                           std::make_tuple("step", 2),
                                           std::make_tuple("tan", 1),
                                           std::make_tuple("tanh", 1),
                                           std::make_tuple("trunc", 1)));
1086 
// Fixture for integer builtins whose arguments all share one type (both u32
// and i32 variants are exercised). Param: (builtin name, argument count).
using IntegerAllMatching =
    ResolverBuiltinsValidationTestWithParams<std::tuple<std::string, uint32_t>>;
1089 
// Calls the parameterized builtin with `num_params` u32 scalars and checks
// that the call yields a u32.
TEST_P(IntegerAllMatching, ScalarUnsigned) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(Construct<uint32_t>(1));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::U32>());
}
1104 
// Calls the parameterized builtin with `num_params` vec2<u32> arguments and
// checks that the call yields a vec2<u32>.
TEST_P(IntegerAllMatching, Vec2Unsigned) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec2<uint32_t>(1u, 1u));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 2u);
  EXPECT_TRUE(result_ty->type()->Is<sem::U32>());
}
1119 
// Calls the parameterized builtin with `num_params` vec3<u32> arguments and
// checks that the call yields a vec3<u32>.
TEST_P(IntegerAllMatching, Vec3Unsigned) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec3<uint32_t>(1u, 1u, 1u));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 3u);
  EXPECT_TRUE(result_ty->type()->Is<sem::U32>());
}
1134 
// Calls the parameterized builtin with `num_params` vec4<u32> arguments and
// checks that the call yields a vec4<u32>.
TEST_P(IntegerAllMatching, Vec4Unsigned) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec4<uint32_t>(1u, 1u, 1u, 1u));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 4u);
  EXPECT_TRUE(result_ty->type()->Is<sem::U32>());
}
1149 
// Calls the parameterized builtin with `num_params` i32 scalars and checks
// that the call yields an i32.
TEST_P(IntegerAllMatching, ScalarSigned) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(Construct<int32_t>(1));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::I32>());
}
1164 
// Calls the parameterized builtin with `num_params` vec2<i32> arguments and
// checks that the call yields a vec2<i32>.
TEST_P(IntegerAllMatching, Vec2Signed) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec2<int32_t>(1, 1));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 2u);
  EXPECT_TRUE(result_ty->type()->Is<sem::I32>());
}
1179 
// Calls the parameterized builtin with `num_params` vec3<i32> arguments and
// checks that the call yields a vec3<i32>.
TEST_P(IntegerAllMatching, Vec3Signed) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec3<int32_t>(1, 1, 1));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 3u);
  EXPECT_TRUE(result_ty->type()->Is<sem::I32>());
}
1194 
// Calls the parameterized builtin with `num_params` vec4<i32> arguments and
// checks that the call yields a vec4<i32>.
TEST_P(IntegerAllMatching, Vec4Signed) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec4<int32_t>(1, 1, 1, 1));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  auto* result_ty = TypeOf(builtin)->As<sem::Vector>();
  ASSERT_NE(result_ty, nullptr);
  EXPECT_EQ(result_ty->Width(), 4u);
  EXPECT_TRUE(result_ty->type()->Is<sem::I32>());
}
1209 
// Each entry is (builtin name, argument count). The list order determines the
// generated test-name suffixes, so append new entries rather than reordering.
INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
                         IntegerAllMatching,
                         ::testing::Values(std::make_tuple("abs", 1),
                                           std::make_tuple("clamp", 3),
                                           std::make_tuple("countOneBits", 1),
                                           std::make_tuple("max", 2),
                                           std::make_tuple("min", 2),
                                           std::make_tuple("reverseBits", 1)));
1218 
// Fixture for builtins that take boolean vector arguments.
// Param: (builtin name, argument count).
using BooleanVectorInput =
    ResolverBuiltinsValidationTestWithParams<std::tuple<std::string, uint32_t>>;
1221 
// Calls the parameterized builtin (all/any) with vec2<bool> arguments and
// checks that it reduces to a scalar bool.
TEST_P(BooleanVectorInput, Vec2) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec2<bool>(true, true));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::Bool>());
}
1235 
// Calls the parameterized builtin (all/any) with vec3<bool> arguments and
// checks that it reduces to a scalar bool.
TEST_P(BooleanVectorInput, Vec3) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec3<bool>(true, true, true));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::Bool>());
}
1249 
// Calls the parameterized builtin (all/any) with vec4<bool> arguments and
// checks that it reduces to a scalar bool.
TEST_P(BooleanVectorInput, Vec4) {
  std::string name = std::get<0>(GetParam());
  uint32_t num_params = std::get<1>(GetParam());

  ast::ExpressionList params;
  for (uint32_t i = 0; i < num_params; ++i) {
    params.push_back(vec4<bool>(true, true, true, true));
  }
  auto* builtin = Call(name, params);
  WrapInFunction(builtin);

  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::Bool>());
}
1263 
// Each entry is (builtin name, argument count).
INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
                         BooleanVectorInput,
                         ::testing::Values(std::make_tuple("all", 1),
                                           std::make_tuple("any", 1)));
1268 
// Fixture for the 4x8 packing builtins; the parameter is the builtin name.
using DataPacking4x8 = ResolverBuiltinsValidationTestWithParams<std::string>;
1270 
// A 4x8 packing builtin applied to a vec4<f32> resolves to a u32.
TEST_P(DataPacking4x8, Float_Vec4) {
  auto name = GetParam();
  auto* builtin = Call(name, vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f));
  WrapInFunction(builtin);
  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::U32>());
}
1277 
// Both 4x8 packing variants (signed and unsigned normalized).
INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
                         DataPacking4x8,
                         ::testing::Values("pack4x8snorm", "pack4x8unorm"));
1281 
// Fixture for the 2x16 packing builtins; the parameter is the builtin name.
using DataPacking2x16 = ResolverBuiltinsValidationTestWithParams<std::string>;
1283 
// A 2x16 packing builtin applied to a vec2<f32> resolves to a u32.
TEST_P(DataPacking2x16, Float_Vec2) {
  auto name = GetParam();
  auto* builtin = Call(name, vec2<f32>(1.0f, 1.0f));
  WrapInFunction(builtin);
  // Fatal: TypeOf() is dereferenced below and may be null if resolving failed.
  ASSERT_TRUE(r()->Resolve()) << r()->error();
  EXPECT_TRUE(TypeOf(builtin)->Is<sem::U32>());
}
1290 
// All 2x16 packing variants: signed/unsigned normalized and half-float.
INSTANTIATE_TEST_SUITE_P(ResolverBuiltinsValidationTest,
                         DataPacking2x16,
                         ::testing::Values("pack2x16snorm",
                                           "pack2x16unorm",
                                           "pack2x16float"));
1296 
1297 }  // namespace
1298 }  // namespace resolver
1299 }  // namespace tint
1300