// A game about forced loneliness, made by TACStudios
1#ifndef THREADING_EMU_IMPL
2#define THREADING_EMU_IMPL
3
// If the user didn't specify a wave size, we assume that their code is "wave size independent" and that they don't
// care which size is actually used. In this case, we automatically select an arbitrary size for them since the
// emulation logic depends on having *some* known size.
#ifndef THREADING_WAVE_SIZE
#define THREADING_WAVE_SIZE 32
#endif

// The emulation derives a thread's lane with a bit mask (indexG & (THREADING_WAVE_SIZE - 1)), so a wave size that is
// not a power of two silently yields wrong lane indices. Wave::Ballot packs the result into a uint4, so it can only
// describe up to 128 lanes. Reject such configurations at compile time rather than producing wrong results at runtime.
#if THREADING_WAVE_SIZE < 1 || THREADING_WAVE_SIZE > 128
#error "THREADING_WAVE_SIZE must be in the range [1, 128]."
#endif
#if (THREADING_WAVE_SIZE & (THREADING_WAVE_SIZE - 1)) != 0
#error "THREADING_WAVE_SIZE must be a power of two."
#endif
10
11namespace Threading
12{
    // Group-shared scratch buffer behind every emulated wave/group operation in this file: one 32-bit slot per thread
    // of the group, indexed by the thread's flat index (indexG / groupIndex). Values are stored as raw bits via asuint()
    // and read back through asuint/asint/asfloat, which is why only 32-bit scalar types can be emulated.
    // Currently we only cover scalar types as at the time of writing this utility library we only needed emulation for those.
    // Support for vector types is currently not there but can be added as needed (and this comment removed).
    groupshared uint g_Scratch[THREADING_BLOCK_SIZE];
16
// Emulated wave-wide reduction for an infix operator OP (+, *, &, |, ^).
// Expands inside a Wave member function whose argument is named `v`. Each lane publishes `v` into its own scratch
// slot; a tree over the lanes of the current wave then halves the active stride every round until the combined value
// sits in the wave's first slot (g_Scratch[offset]), which every lane returns.
// The expansion contains group-wide barriers, so every thread of the group has to execute it.
#define EMULATED_WAVE_REDUCE(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[indexG] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint stride = THREADING_WAVE_SIZE >> 1u; stride != 0u; stride >>= 1u) \
    { \
        if (indexL < stride) \
        { \
            TYPE lhs = as##TYPE(g_Scratch[indexG]); \
            TYPE rhs = as##TYPE(g_Scratch[indexG + stride]); \
            g_Scratch[indexG] = asuint(lhs OP rhs); \
        } \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[offset]);
29
// Emulated wave-wide reduction for a two-argument function OP (max, min).
// Same tree scheme as EMULATED_WAVE_REDUCE; the only difference is that OP is applied as OP(lhs, rhs).
// Expands inside a Wave member function whose argument is named `v`; the result ends up in g_Scratch[offset].
#define EMULATED_WAVE_REDUCE_CMP(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[indexG] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint stride = THREADING_WAVE_SIZE >> 1u; stride != 0u; stride >>= 1u) \
    { \
        if (indexL < stride) \
        { \
            TYPE lhs = as##TYPE(g_Scratch[indexG]); \
            TYPE rhs = as##TYPE(g_Scratch[indexG + stride]); \
            g_Scratch[indexG] = asuint(OP(lhs, rhs)); \
        } \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[offset]);
42
// Emulated exclusive wave scan for an infix operator OP whose identity is FILL_VALUE (callers pass 0 for +, 1 for *).
// Expands inside a Wave member function whose argument is named `v`. An inclusive Hillis-Steele style scan runs over
// the lanes of the current wave (one round per doubling of `dist`). Each lane then returns its left neighbour's
// inclusive value, and lane 0 returns FILL_VALUE.
// The barrier between computing `combined` and storing it keeps this round's writes from racing with other lanes'
// reads of the previous round's values.
#define EMULATED_WAVE_PREFIX(TYPE, OP, FILL_VALUE) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[indexG] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint dist = 1u; dist < THREADING_WAVE_SIZE; dist <<= 1u) \
    { \
        TYPE neighbor = FILL_VALUE; \
        if (indexL >= dist) \
            neighbor = as##TYPE(g_Scratch[indexG - dist]); \
        TYPE combined = as##TYPE(g_Scratch[indexG]) OP neighbor; \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[indexG] = asuint(combined); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    if (indexL == 0u) \
        return FILL_VALUE; \
    return as##TYPE(g_Scratch[indexG - 1]);
64
65 uint Wave::GetIndex() { return indexW; }
66
67 void Wave::Init(uint groupIndex)
68 {
69 indexG = groupIndex;
70 indexW = indexG / THREADING_WAVE_SIZE;
71 indexL = indexG & (THREADING_WAVE_SIZE - 1);
72 offset = indexW * THREADING_WAVE_SIZE;
73 }
74
75 // WARNING:
76 // These emulated functions do not emulate the execution mask.
77 // So they WILL produce incorrect results if you have divergent lanes.
78
79 #define DEFINE_API_FOR_TYPE(TYPE) \
80 bool Wave::AllEqual(TYPE v) { return AllTrue(ReadLaneFirst(v) == v); } \
81 TYPE Wave::Product(TYPE v) { EMULATED_WAVE_REDUCE(TYPE, *) } \
82 TYPE Wave::Sum(TYPE v) { EMULATED_WAVE_REDUCE(TYPE, +) } \
83 TYPE Wave::Max(TYPE v) { EMULATED_WAVE_REDUCE_CMP(TYPE, max) } \
84 TYPE Wave::Min(TYPE v) { EMULATED_WAVE_REDUCE_CMP(TYPE, min) } \
85 TYPE Wave::InclusivePrefixSum (TYPE v) { return PrefixSum(v) + v; } \
86 TYPE Wave::InclusivePrefixProduct (TYPE v) { return PrefixProduct(v) * v; } \
87 TYPE Wave::PrefixSum (TYPE v) { EMULATED_WAVE_PREFIX(TYPE, +, (TYPE)0) } \
88 TYPE Wave::PrefixProduct (TYPE v) { EMULATED_WAVE_PREFIX(TYPE, *, (TYPE)1) } \
89 TYPE Wave::ReadLaneAt(TYPE v, uint i) { GroupMemoryBarrierWithGroupSync(); g_Scratch[indexG] = asuint(v); GroupMemoryBarrierWithGroupSync(); return as##TYPE(g_Scratch[offset + i]); } \
90 TYPE Wave::ReadLaneFirst(TYPE v) { return ReadLaneAt(v, 0u); } \
91
92 // Currently just support scalars.
93 DEFINE_API_FOR_TYPE(uint)
94 DEFINE_API_FOR_TYPE(int)
95 DEFINE_API_FOR_TYPE(float)
96
97 // The following emulated functions need only be declared once.
98 uint Wave::GetLaneCount() { return THREADING_WAVE_SIZE; }
99 uint Wave::GetLaneIndex() { return indexL; }
100 bool Wave::IsFirstLane() { return indexL == 0u; }
101 bool Wave::AllTrue(bool v) { return And(v) != 0u; }
102 bool Wave::AnyTrue(bool v) { return Or (v) != 0u; }
103 uint Wave::PrefixCountBits(bool v) { return PrefixSum((uint)v); }
104 uint Wave::And(uint v) { EMULATED_WAVE_REDUCE(uint, &) }
105 uint Wave::Or (uint v) { EMULATED_WAVE_REDUCE(uint, |) }
106 uint Wave::Xor(uint v) { EMULATED_WAVE_REDUCE(uint, ^) }
107
108 uint4 Wave::Ballot(bool v)
109 {
110 uint indexDw = indexL % 32u;
111 uint offsetDw = (indexL / 32u) * 32u;
112 uint indexScratch = offset + offsetDw + indexDw;
113
114 GroupMemoryBarrierWithGroupSync();
115
116 g_Scratch[indexG] = v << indexDw;
117
118 GroupMemoryBarrierWithGroupSync();
119
120 [unroll]
121 for (uint s = min(THREADING_WAVE_SIZE / 2u, 16u); s > 0u; s >>= 1u)
122 {
123 if (indexDw < s)
124 g_Scratch[indexScratch] = g_Scratch[indexScratch] | g_Scratch[indexScratch + s];
125
126 GroupMemoryBarrierWithGroupSync();
127 }
128
129 uint4 result = uint4(g_Scratch[offset], 0, 0, 0);
130
131#if THREADING_WAVE_SIZE > 32
132 result.y = g_Scratch[offset + 32];
133#endif
134
135#if THREADING_WAVE_SIZE > 64
136 result.z = g_Scratch[offset + 64];
137#endif
138
139#if THREADING_WAVE_SIZE > 96
140 result.w = g_Scratch[offset + 96];
141#endif
142
143 return result;
144 }
145
146 uint Wave::CountBits(bool v)
147 {
148 uint4 ballot = Ballot(v);
149
150 uint result = countbits(ballot.x);
151
152#if THREADING_WAVE_SIZE > 32
153 result += countbits(ballot.y);
154#endif
155
156#if THREADING_WAVE_SIZE > 64
157 result += countbits(ballot.z);
158#endif
159
160#if THREADING_WAVE_SIZE > 96
161 result += countbits(ballot.w);
162#endif
163
164 return result;
165 }
166
// Smallest power of two >= n, for 1 <= n <= 2^31. It is written as a constant expression so that loops bounded by it
// can still be fully unrolled. It works by smearing the highest set bit of (n - 1) into every lower bit and adding one.
#define _THREADING_SMEAR_BITS(x, k) ((x) | ((x) >> (k)))
#define _THREADING_NEXT_POW2(n) \
    (_THREADING_SMEAR_BITS(_THREADING_SMEAR_BITS(_THREADING_SMEAR_BITS(_THREADING_SMEAR_BITS( \
        _THREADING_SMEAR_BITS((n) - 1u, 1u), 2u), 4u), 8u), 16u) + 1u)

// Emulated group-wide reductions. EMULATED_GROUP_REDUCE handles infix operators (+, *, &, |, ^) and
// EMULATED_GROUP_REDUCE_CMP handles two-argument functions (max, min).
// Both expand inside a Group member function whose argument is named `v`. Every thread publishes `v` into
// g_Scratch[groupIndex], then a halving tree folds the whole group into g_Scratch[0], which every thread returns.
//
// The tree starts at half of the next power of two >= THREADING_BLOCK_SIZE and skips partner slots past the end of the
// group. For power-of-two group sizes this performs exactly the same operations as starting at
// THREADING_BLOCK_SIZE / 2, so results (including float rounding) are unchanged. For other sizes (e.g. 48 or 96) it
// prevents plain halving from silently dropping elements once the stride becomes odd.
#define EMULATED_GROUP_REDUCE(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = _THREADING_NEXT_POW2(THREADING_BLOCK_SIZE) / 2u; s > 0u; s >>= 1u) \
    { \
        if (groupIndex < s && groupIndex + s < THREADING_BLOCK_SIZE) \
            g_Scratch[groupIndex] = asuint(as##TYPE(g_Scratch[groupIndex]) OP as##TYPE(g_Scratch[groupIndex + s])); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]);

#define EMULATED_GROUP_REDUCE_CMP(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = _THREADING_NEXT_POW2(THREADING_BLOCK_SIZE) / 2u; s > 0u; s >>= 1u) \
    { \
        if (groupIndex < s && groupIndex + s < THREADING_BLOCK_SIZE) \
            g_Scratch[groupIndex] = asuint(OP(as##TYPE(g_Scratch[groupIndex]), as##TYPE(g_Scratch[groupIndex + s]))); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]);
192
// Emulated exclusive group scan for an infix operator OP whose identity is FILL_VALUE (callers pass 0 for +, 1 for *).
// Expands inside a Group member function whose argument is named `v`. An inclusive Hillis-Steele style scan runs over
// all THREADING_BLOCK_SIZE threads (one round per doubling of `dist`). Each thread then returns its left neighbour's
// inclusive value, and thread 0 returns FILL_VALUE.
// The barrier between computing `combined` and storing it keeps this round's writes from racing with other threads'
// reads of the previous round's values.
#define EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint dist = 1u; dist < THREADING_BLOCK_SIZE; dist <<= 1u) \
    { \
        TYPE neighbor = FILL_VALUE; \
        if (groupIndex >= dist) \
            neighbor = as##TYPE(g_Scratch[groupIndex - dist]); \
        TYPE combined = as##TYPE(g_Scratch[groupIndex]) OP neighbor; \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(combined); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    if (groupIndex == 0u) \
        return FILL_VALUE; \
    return as##TYPE(g_Scratch[groupIndex - 1]);
214
215 uint Group::GetWaveCount()
216 {
217 return THREADING_BLOCK_SIZE / THREADING_WAVE_SIZE;
218 }
219
220 #define DEFINE_API_FOR_TYPE_GROUP(TYPE) \
221 bool Group::AllEqual(TYPE v) { return AllTrue(ReadThreadFirst(v) == v); } \
222 TYPE Group::Product(TYPE v) { EMULATED_GROUP_REDUCE(TYPE, *) } \
223 TYPE Group::Sum(TYPE v) { EMULATED_GROUP_REDUCE(TYPE, +) } \
224 TYPE Group::Max(TYPE v) { EMULATED_GROUP_REDUCE_CMP(TYPE, max) } \
225 TYPE Group::Min(TYPE v) { EMULATED_GROUP_REDUCE_CMP(TYPE, min) } \
226 TYPE Group::InclusivePrefixSum (TYPE v) { return PrefixSum(v) + v; } \
227 TYPE Group::InclusivePrefixProduct (TYPE v) { return PrefixProduct(v) * v; } \
228 TYPE Group::PrefixSum (TYPE v) { EMULATED_GROUP_PREFIX(TYPE, +, (TYPE)0) } \
229 TYPE Group::PrefixProduct (TYPE v) { EMULATED_GROUP_PREFIX(TYPE, *, (TYPE)1) } \
230 TYPE Group::ReadThreadAt(TYPE v, uint i) { GroupMemoryBarrierWithGroupSync(); g_Scratch[groupIndex] = asuint(v); GroupMemoryBarrierWithGroupSync(); return as##TYPE(g_Scratch[i]); } \
231 TYPE Group::ReadThreadFirst(TYPE v) { return ReadThreadAt(v, 0u); } \
232 TYPE Group::ReadThreadShuffle(TYPE v, uint i) { return ReadThreadAt(v, i); } \
233
234 // Currently just support scalars.
235 DEFINE_API_FOR_TYPE_GROUP(uint)
236 DEFINE_API_FOR_TYPE_GROUP(int)
237 DEFINE_API_FOR_TYPE_GROUP(float)
238
239 // The following emulated functions need only be declared once.
240 uint Group::GetThreadCount() { return THREADING_BLOCK_SIZE; }
241 uint Group::GetThreadIndex() { return groupIndex; }
242 bool Group::IsFirstThread() { return groupIndex == 0u; }
243 bool Group::AllTrue(bool v) { return And(v) != 0u; }
244 bool Group::AnyTrue(bool v) { return Or (v) != 0u; }
245 uint Group::PrefixCountBits(bool v) { return PrefixSum((uint)v); }
246 uint Group::And(uint v) { EMULATED_GROUP_REDUCE(uint, &) }
247 uint Group::Or (uint v) { EMULATED_GROUP_REDUCE(uint, |) }
248 uint Group::Xor(uint v) { EMULATED_GROUP_REDUCE(uint, ^) }
249
250 GroupBallot Group::Ballot(bool v)
251 {
252 uint indexDw = groupIndex % 32u;
253 uint offsetDw = (groupIndex / 32u) * 32u;
254 uint indexScratch = offsetDw + indexDw;
255
256 GroupMemoryBarrierWithGroupSync();
257
258 g_Scratch[groupIndex] = v << indexDw;
259
260 GroupMemoryBarrierWithGroupSync();
261
262 [unroll]
263 for (uint s = min(THREADING_BLOCK_SIZE / 2u, 16u); s > 0u; s >>= 1u)
264 {
265 if (indexDw < s)
266 g_Scratch[indexScratch] = g_Scratch[indexScratch] | g_Scratch[indexScratch + s];
267
268 GroupMemoryBarrierWithGroupSync();
269 }
270
271 GroupBallot ballot = (GroupBallot)0;
272
273 // Explicitly mark this loop as "unroll" to avoid warnings about assigning to an array reference
274 [unroll]
275 for (uint dwordIndex = 0; dwordIndex < _THREADING_GROUP_BALLOT_DWORDS; ++dwordIndex)
276 {
277 ballot.dwords[dwordIndex] = g_Scratch[dwordIndex * 32];
278 }
279
280 return ballot;
281 }
282
283 uint Group::CountBits(bool v)
284 {
285 return Ballot(v).CountBits();
286 }
287}
288#endif