1// Copyright (c) 2022 Advanced Micro Devices, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "test/opt/pass_fixture.h"
16#include "test/opt/pass_utils.h"
17
18namespace spvtools {
19namespace opt {
20namespace {
21
// Fixture alias shared by the tests below; it is PassTest instantiated over
// the plain ::testing::Test base.
22using FixFuncCallArgumentsTest = PassTest<::testing::Test>;
// Simple: runs FixFuncCallArgumentsPass over a module in which `main` calls
// `%fn` with two OpAccessChain results as arguments:
//   * %21 - a Function-storage pointer into member 0 of the local struct %t
//   * %23 - a Uniform-storage pointer into the runtime array of buffer %r1
// `%fn` declares both parameters as %_ptr_Function_float.
//
// The CHECK patterns embedded in the assembly expect the output to contain:
//   * two OpVariable %_ptr_Function_float Function instructions, plus an
//     OpVariable of type %_ptr_Function_T,
//   * both access chains still present,
//   * before the call, an OpLoad from each access chain whose result is
//     OpStore'd into one of the float variables,
//   * the OpFunctionCall taking those two variables as its arguments,
//   * after the call, an OpLoad from each variable whose result is OpStore'd
//     back through the corresponding access chain.
23TEST_F(FixFuncCallArgumentsTest, Simple) {
24  const std::string text = R"(
25;
26; CHECK: [[v0:%\w+]] = OpVariable %_ptr_Function_float Function
27; CHECK: [[v1:%\w+]] = OpVariable %_ptr_Function_float Function
28; CHECK: [[v2:%\w+]] = OpVariable %_ptr_Function_T Function
29; CHECK: [[ac0:%\w+]] = OpAccessChain %_ptr_Function_float %t %int_0
30; CHECK: [[ac1:%\w+]] = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
31; CHECK: [[ld0:%\w+]] = OpLoad %float [[ac0]]
32; CHECK:                OpStore [[v1]] [[ld0]]
33; CHECK: [[ld1:%\w+]] = OpLoad %float [[ac1]]
34; CHECK:                OpStore [[v0]] [[ld1]]
35; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[v1]] [[v0]]
36; CHECK: [[ld2:%\w+]] = OpLoad %float [[v0]]
37; CHECK: OpStore [[ac1]] [[ld2]]
38; CHECK: [[ld3:%\w+]] = OpLoad %float [[v1]]
39; CHECK: OpStore [[ac0]] [[ld3]]
40;
41OpCapability Shader
42OpCapability Linkage
43OpMemoryModel Logical GLSL450
44OpSource HLSL 630
45OpName %type_RWStructuredBuffer_float "type.RWStructuredBuffer.float"
46OpName %r1 "r1"
47OpName %type_ACSBuffer_counter "type.ACSBuffer.counter"
48OpMemberName %type_ACSBuffer_counter 0 "counter"
49OpName %counter_var_r1 "counter.var.r1"
50OpName %main "main"
51OpName %bb_entry "bb.entry"
52OpName %T "T"
53OpMemberName %T 0 "t0"
54OpName %t "t"
55OpName %fn "fn"
56OpName %p0 "p0"
57OpName %p2 "p2"
58OpName %bb_entry_0 "bb.entry"
59OpDecorate %main LinkageAttributes "main" Export
60OpDecorate %r1 DescriptorSet 0
61OpDecorate %r1 Binding 0
62OpDecorate %counter_var_r1 DescriptorSet 0
63OpDecorate %counter_var_r1 Binding 1
64OpDecorate %_runtimearr_float ArrayStride 4
65OpMemberDecorate %type_RWStructuredBuffer_float 0 Offset 0
66OpDecorate %type_RWStructuredBuffer_float BufferBlock
67OpMemberDecorate %type_ACSBuffer_counter 0 Offset 0
68OpDecorate %type_ACSBuffer_counter BufferBlock
69%int = OpTypeInt 32 1
70%int_0 = OpConstant %int 0
71%uint = OpTypeInt 32 0
72%uint_0 = OpConstant %uint 0
73%int_1 = OpConstant %int 1
74%float = OpTypeFloat 32
75%_runtimearr_float = OpTypeRuntimeArray %float
76%type_RWStructuredBuffer_float = OpTypeStruct %_runtimearr_float
77%_ptr_Uniform_type_RWStructuredBuffer_float = OpTypePointer Uniform %type_RWStructuredBuffer_float
78%type_ACSBuffer_counter = OpTypeStruct %int
79%_ptr_Uniform_type_ACSBuffer_counter = OpTypePointer Uniform %type_ACSBuffer_counter
80%15 = OpTypeFunction %int
81%T = OpTypeStruct %float
82%_ptr_Function_T = OpTypePointer Function %T
83%_ptr_Function_float = OpTypePointer Function %float
84%_ptr_Uniform_float = OpTypePointer Uniform %float
85%void = OpTypeVoid
86%27 = OpTypeFunction %void %_ptr_Function_float %_ptr_Function_float
87%r1 = OpVariable %_ptr_Uniform_type_RWStructuredBuffer_float Uniform
88%counter_var_r1 = OpVariable %_ptr_Uniform_type_ACSBuffer_counter Uniform
89%main = OpFunction %int None %15
90%bb_entry = OpLabel
91%t = OpVariable %_ptr_Function_T Function
92%21 = OpAccessChain %_ptr_Function_float %t %int_0
93%23 = OpAccessChain %_ptr_Uniform_float %r1 %int_0 %uint_0
94%25 = OpFunctionCall %void %fn %21 %23
95OpReturnValue %int_1
96OpFunctionEnd
97%fn = OpFunction %void DontInline %27
98%p0 = OpFunctionParameter %_ptr_Function_float
99%p2 = OpFunctionParameter %_ptr_Function_float
100%bb_entry_0 = OpLabel
101OpReturn
102OpFunctionEnd
103)";
104
  // Run the pass and match its output against the CHECK lines above.
  // NOTE(review): the trailing `true` is presumably the fixture's
  // "validate the result" flag - confirm against PassTest's declaration.
105  SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, true);
106}
107
// NotAccessChainInput: runs FixFuncCallArgumentsPass over a module in which
// the only argument passed to `%fn` is the result of an OpCopyObject of the
// local variable %t, i.e. the argument is not produced by an OpAccessChain.
//
// The CHECK patterns expect the OpCopyObject to still be present in the
// output and the OpFunctionCall to take its result directly as the argument.
108TEST_F(FixFuncCallArgumentsTest, NotAccessChainInput) {
109  const std::string text = R"(
110;
111; CHECK: [[o:%\w+]] = OpCopyObject %_ptr_Function_float %t
112; CHECK: [[func:%\w+]] = OpFunctionCall %void %fn [[o]]
113;
114OpCapability Shader
115OpCapability Linkage
116OpMemoryModel Logical GLSL450
117OpSource HLSL 630
118OpName %main "main"
119OpName %bb_entry "bb.entry"
120OpName %t "t"
121OpName %fn "fn"
122OpName %p0 "p0"
123OpName %bb_entry_0 "bb.entry"
124OpDecorate %main LinkageAttributes "main" Export
125%int = OpTypeInt 32 1
126%int_1 = OpConstant %int 1
127%4 = OpTypeFunction %int
128%float = OpTypeFloat 32
129%_ptr_Function_float = OpTypePointer Function %float
130%void = OpTypeVoid
131%12 = OpTypeFunction %void %_ptr_Function_float
132%main = OpFunction %int None %4
133%bb_entry = OpLabel
134%t = OpVariable %_ptr_Function_float Function
135%t1 = OpCopyObject %_ptr_Function_float %t
136%10 = OpFunctionCall %void %fn %t1
137OpReturnValue %int_1
138OpFunctionEnd
139%fn = OpFunction %void DontInline %12
140%p0 = OpFunctionParameter %_ptr_Function_float
141%bb_entry_0 = OpLabel
142OpReturn
143OpFunctionEnd
144)";
145
  // Run the pass and match its output against the CHECK lines above.
  // NOTE(review): the trailing `false` is presumably the fixture's
  // "validate the result" flag, disabled for this case - confirm against
  // PassTest's declaration.
146  SinglePassRunAndMatch<FixFuncCallArgumentsPass>(text, false);
147}
148
149}  // namespace
150}  // namespace opt
151}  // namespace spvtools