// A game about forced loneliness, made by TACStudios.
// (Captured from the repository's master branch: 288 lines, 14 kB.)
#ifndef THREADING_EMU_IMPL
#define THREADING_EMU_IMPL

// If the user didn't specify a wave size, we assume that their code is "wave size independent" and that they don't
// care which size is actually used. In this case, we automatically select an arbitrary size for them since the
// emulation logic depends on having *some* known size.
#ifndef THREADING_WAVE_SIZE
#define THREADING_WAVE_SIZE 32
#endif

// Wave::Init derives the lane index with a bit mask and the wave reductions halve their stride every step; both are
// only correct for power-of-two wave sizes, so reject anything else at compile time instead of computing garbage.
#if (THREADING_WAVE_SIZE & (THREADING_WAVE_SIZE - 1)) != 0
#error "THREADING_WAVE_SIZE must be a power of two."
#endif

namespace Threading
{
    // Currently we only cover scalar types as at the time of writing this utility library we only needed emulation for those.
    // Support for vector types is currently not there but can be added as needed (and this comment removed).

    // Scratch memory shared by every emulated operation: one 32-bit slot per thread of the group, indexed by the
    // thread's flattened group index. Values of every supported scalar type are stored bit-cast to uint.
    groupshared uint g_Scratch[THREADING_BLOCK_SIZE];

// The EMULATED_WAVE_* macros expand into the body of a Wave member function. They read the argument `v` and the Wave
// state written by Wave::Init (indexG, indexL, offset). Each one starts with a barrier so that no thread can still be
// reading g_Scratch for a previous emulated operation while it is overwritten. Because they use
// GroupMemoryBarrierWithGroupSync, every thread of the group must reach the call.

// Tree reduction over the lanes of the calling thread's emulated wave with the binary operator OP.
// The result is left in the wave's first slot (offset) and returned to every lane.
#define EMULATED_WAVE_REDUCE(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[indexG] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = THREADING_WAVE_SIZE / 2u; s > 0u; s >>= 1u) \
    { \
        if (indexL < s) \
            g_Scratch[indexG] = asuint(as##TYPE(g_Scratch[indexG]) OP as##TYPE(g_Scratch[indexG + s])); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[offset]);

// Same as EMULATED_WAVE_REDUCE, but OP is a two-argument function such as min or max.
#define EMULATED_WAVE_REDUCE_CMP(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[indexG] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = THREADING_WAVE_SIZE / 2u; s > 0u; s >>= 1u) \
    { \
        if (indexL < s) \
            g_Scratch[indexG] = asuint(OP(as##TYPE(g_Scratch[indexG]), as##TYPE(g_Scratch[indexG + s]))); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[offset]);

// Exclusive prefix scan over the lanes of the emulated wave. The loop builds an inclusive scan in place, doubling the
// look-back distance s each step; lane i then reads slot i - 1. FILL_VALUE must be the identity of OP, because lanes
// with nothing s slots behind them combine their value with it. Lane 0 returns FILL_VALUE.
#define EMULATED_WAVE_PREFIX(TYPE, OP, FILL_VALUE) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[indexG] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = 1u; s < THREADING_WAVE_SIZE; s <<= 1u) \
    { \
        TYPE nv = FILL_VALUE; \
        if (indexL >= s) \
        { \
            nv = as##TYPE(g_Scratch[indexG - s]); \
        } \
        nv = as##TYPE(g_Scratch[indexG]) OP nv; \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[indexG] = asuint(nv); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    TYPE result = FILL_VALUE; \
    if (indexL > 0u) \
        result = as##TYPE(g_Scratch[indexG - 1]); \
    return result;

    // Index of the emulated wave this thread belongs to (set by Wave::Init).
    uint Wave::GetIndex() { return indexW; }

    // Caches the wave-relative indices of the calling thread.
    //   groupIndex - flattened index of the thread within its group (presumably SV_GroupIndex; confirm at call sites).
    // Sets indexW (emulated wave index), indexL (lane within that wave) and offset (first g_Scratch slot of the wave).
    void Wave::Init(uint groupIndex)
    {
        indexG = groupIndex;
        indexW = indexG / THREADING_WAVE_SIZE;
        indexL = indexG & (THREADING_WAVE_SIZE - 1);
        offset = indexW * THREADING_WAVE_SIZE;
    }

    // WARNING:
    // These emulated functions do not emulate the execution mask.
    // So they WILL produce incorrect results if you have divergent lanes.
    // They also synchronize the whole thread group, so every thread of the group must call them.

// Defines the per-type emulated Wave API. ReadLaneAt does not bounds-check `i`: values outside
// [0, THREADING_WAVE_SIZE) read another wave's slot (or past the end of g_Scratch).
#define DEFINE_API_FOR_TYPE(TYPE) \
    bool Wave::AllEqual(TYPE v)               { return AllTrue(ReadLaneFirst(v) == v); } \
    TYPE Wave::Product(TYPE v)                { EMULATED_WAVE_REDUCE(TYPE, *) } \
    TYPE Wave::Sum(TYPE v)                    { EMULATED_WAVE_REDUCE(TYPE, +) } \
    TYPE Wave::Max(TYPE v)                    { EMULATED_WAVE_REDUCE_CMP(TYPE, max) } \
    TYPE Wave::Min(TYPE v)                    { EMULATED_WAVE_REDUCE_CMP(TYPE, min) } \
    TYPE Wave::InclusivePrefixSum(TYPE v)     { return PrefixSum(v) + v; } \
    TYPE Wave::InclusivePrefixProduct(TYPE v) { return PrefixProduct(v) * v; } \
    TYPE Wave::PrefixSum(TYPE v)              { EMULATED_WAVE_PREFIX(TYPE, +, (TYPE)0) } \
    TYPE Wave::PrefixProduct(TYPE v)          { EMULATED_WAVE_PREFIX(TYPE, *, (TYPE)1) } \
    TYPE Wave::ReadLaneAt(TYPE v, uint i) \
    { \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[indexG] = asuint(v); \
        GroupMemoryBarrierWithGroupSync(); \
        return as##TYPE(g_Scratch[offset + i]); \
    } \
    TYPE Wave::ReadLaneFirst(TYPE v)          { return ReadLaneAt(v, 0u); }

    // Currently just support scalars.
    DEFINE_API_FOR_TYPE(uint)
    DEFINE_API_FOR_TYPE(int)
    DEFINE_API_FOR_TYPE(float)
// ---------------------------------------------------------------------------------------------------------------------
// Wave operations that do not depend on TYPE, so they are defined only once.
// ---------------------------------------------------------------------------------------------------------------------
    uint Wave::GetLaneCount()         { return THREADING_WAVE_SIZE; }
    uint Wave::GetLaneIndex()         { return indexL; }
    bool Wave::IsFirstLane()          { return indexL == 0u; }
    bool Wave::AllTrue(bool v)        { return And(v) != 0u; }
    bool Wave::AnyTrue(bool v)        { return Or (v) != 0u; }
    uint Wave::PrefixCountBits(bool v) { return PrefixSum((uint)v); }
    uint Wave::And(uint v)            { EMULATED_WAVE_REDUCE(uint, &) }
    uint Wave::Or (uint v)            { EMULATED_WAVE_REDUCE(uint, |) }
    uint Wave::Xor(uint v)            { EMULATED_WAVE_REDUCE(uint, ^) }

    // Emulated ballot: bit (i % 32) of component (i / 32) is set when lane i of this wave passed v == true.
    // Components beyond the wave size stay 0.
    uint4 Wave::Ballot(bool v)
    {
        // Bit position of this lane inside its 32-lane dword.
        uint indexDw = indexL % 32u;

        GroupMemoryBarrierWithGroupSync();

        g_Scratch[indexG] = (uint)v << indexDw;

        GroupMemoryBarrierWithGroupSync();

        // OR-reduce every 32-lane chunk of the wave into the chunk's first slot.
        [unroll]
        for (uint s = min(THREADING_WAVE_SIZE / 2u, 16u); s > 0u; s >>= 1u)
        {
            if (indexDw < s)
                g_Scratch[indexG] |= g_Scratch[indexG + s];

            GroupMemoryBarrierWithGroupSync();
        }

        uint4 result = uint4(g_Scratch[offset], 0, 0, 0);

#if THREADING_WAVE_SIZE > 32
        result.y = g_Scratch[offset + 32];
#endif

#if THREADING_WAVE_SIZE > 64
        result.z = g_Scratch[offset + 64];
#endif

#if THREADING_WAVE_SIZE > 96
        result.w = g_Scratch[offset + 96];
#endif

        return result;
    }

    // Number of lanes in this wave that passed v == true.
    uint Wave::CountBits(bool v)
    {
        uint4 ballot = Ballot(v);

        uint result = countbits(ballot.x);

#if THREADING_WAVE_SIZE > 32
        result += countbits(ballot.y);
#endif

#if THREADING_WAVE_SIZE > 64
        result += countbits(ballot.z);
#endif

#if THREADING_WAVE_SIZE > 96
        result += countbits(ballot.w);
#endif

        return result;
    }

// ---------------------------------------------------------------------------------------------------------------------
// Group emulation: the same operations across all THREADING_BLOCK_SIZE threads of the group. The EMULATED_GROUP_*
// macros expand into the body of a Group member function and read the argument `v` and the member `groupIndex`.
// They synchronize the whole group, so every thread of the group must reach the call.
// ---------------------------------------------------------------------------------------------------------------------

// Reduction across the whole group with the binary operator OP. Each step folds the upper part of the `n` slots that
// are still live onto the lower part with stride s = ceil(n / 2), so THREADING_BLOCK_SIZE does not have to be a power
// of two (plain halving would drop a slot, e.g. slot 2 for 96 threads). For power-of-two sizes the steps are exactly
// the classic halving ones.
#define EMULATED_GROUP_REDUCE(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint n = THREADING_BLOCK_SIZE; n > 1u; n = (n + 1u) >> 1u) \
    { \
        uint s = (n + 1u) >> 1u; \
        if (groupIndex + s < n) \
            g_Scratch[groupIndex] = asuint(as##TYPE(g_Scratch[groupIndex]) OP as##TYPE(g_Scratch[groupIndex + s])); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]);

// Same as EMULATED_GROUP_REDUCE, but OP is a two-argument function such as min or max.
#define EMULATED_GROUP_REDUCE_CMP(TYPE, OP) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint n = THREADING_BLOCK_SIZE; n > 1u; n = (n + 1u) >> 1u) \
    { \
        uint s = (n + 1u) >> 1u; \
        if (groupIndex + s < n) \
            g_Scratch[groupIndex] = asuint(OP(as##TYPE(g_Scratch[groupIndex]), as##TYPE(g_Scratch[groupIndex + s]))); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    return as##TYPE(g_Scratch[0]);

// Exclusive prefix scan across the group: an in-place inclusive scan with doubling look-back distance, after which
// thread i reads slot i - 1. FILL_VALUE must be the identity of OP. Thread 0 returns FILL_VALUE.
#define EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE) \
    GroupMemoryBarrierWithGroupSync(); \
    g_Scratch[groupIndex] = asuint(v); \
    GroupMemoryBarrierWithGroupSync(); \
    [unroll] \
    for (uint s = 1u; s < THREADING_BLOCK_SIZE; s <<= 1u) \
    { \
        TYPE nv = FILL_VALUE; \
        if (groupIndex >= s) \
        { \
            nv = as##TYPE(g_Scratch[groupIndex - s]); \
        } \
        nv = as##TYPE(g_Scratch[groupIndex]) OP nv; \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(nv); \
        GroupMemoryBarrierWithGroupSync(); \
    } \
    TYPE result = FILL_VALUE; \
    if (groupIndex > 0u) \
        result = as##TYPE(g_Scratch[groupIndex - 1]); \
    return result;

    // Number of emulated waves in the group. Rounded up so that it agrees with Wave::GetIndex(), which gives the
    // threads of a trailing partial wave (and every thread of a group smaller than one wave) a wave index of their own.
    uint Group::GetWaveCount()
    {
        return (THREADING_BLOCK_SIZE + THREADING_WAVE_SIZE - 1u) / THREADING_WAVE_SIZE;
    }

// Defines the per-type emulated Group API. ReadThreadAt does not bounds-check `i` against THREADING_BLOCK_SIZE.
#define DEFINE_API_FOR_TYPE_GROUP(TYPE) \
    bool Group::AllEqual(TYPE v)                  { return AllTrue(ReadThreadFirst(v) == v); } \
    TYPE Group::Product(TYPE v)                   { EMULATED_GROUP_REDUCE(TYPE, *) } \
    TYPE Group::Sum(TYPE v)                       { EMULATED_GROUP_REDUCE(TYPE, +) } \
    TYPE Group::Max(TYPE v)                       { EMULATED_GROUP_REDUCE_CMP(TYPE, max) } \
    TYPE Group::Min(TYPE v)                       { EMULATED_GROUP_REDUCE_CMP(TYPE, min) } \
    TYPE Group::InclusivePrefixSum(TYPE v)        { return PrefixSum(v) + v; } \
    TYPE Group::InclusivePrefixProduct(TYPE v)    { return PrefixProduct(v) * v; } \
    TYPE Group::PrefixSum(TYPE v)                 { EMULATED_GROUP_PREFIX(TYPE, +, (TYPE)0) } \
    TYPE Group::PrefixProduct(TYPE v)             { EMULATED_GROUP_PREFIX(TYPE, *, (TYPE)1) } \
    TYPE Group::ReadThreadAt(TYPE v, uint i) \
    { \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(v); \
        GroupMemoryBarrierWithGroupSync(); \
        return as##TYPE(g_Scratch[i]); \
    } \
    TYPE Group::ReadThreadFirst(TYPE v)           { return ReadThreadAt(v, 0u); } \
    TYPE Group::ReadThreadShuffle(TYPE v, uint i) { return ReadThreadAt(v, i); }

    // Currently just support scalars.
    DEFINE_API_FOR_TYPE_GROUP(uint)
    DEFINE_API_FOR_TYPE_GROUP(int)
    DEFINE_API_FOR_TYPE_GROUP(float)

    // The following emulated functions do not depend on TYPE, so they are defined only once.
    uint Group::GetThreadCount()        { return THREADING_BLOCK_SIZE; }
    uint Group::GetThreadIndex()        { return groupIndex; }
    bool Group::IsFirstThread()         { return groupIndex == 0u; }
    bool Group::AllTrue(bool v)         { return And(v) != 0u; }
    bool Group::AnyTrue(bool v)         { return Or (v) != 0u; }
    uint Group::PrefixCountBits(bool v) { return PrefixSum((uint)v); }
    uint Group::And(uint v)             { EMULATED_GROUP_REDUCE(uint, &) }
    uint Group::Or (uint v)             { EMULATED_GROUP_REDUCE(uint, |) }
    uint Group::Xor(uint v)             { EMULATED_GROUP_REDUCE(uint, ^) }

    // Emulated group-wide ballot: bit (i % 32) of ballot.dwords[i / 32] is set when thread i passed v == true.
    GroupBallot Group::Ballot(bool v)
    {
        // Bit position of this thread inside its 32-thread dword.
        uint indexDw = groupIndex % 32u;

        GroupMemoryBarrierWithGroupSync();

        g_Scratch[groupIndex] = (uint)v << indexDw;

        GroupMemoryBarrierWithGroupSync();

        // OR-reduce every 32-thread chunk into the chunk's first slot. n is the chunk width still being folded and
        // s = ceil(n / 2), so groups smaller than 32 threads need not be a power of two; the groupIndex bound keeps a
        // trailing partial chunk from reading past the end of g_Scratch (missing threads contribute 0 to the OR).
        [unroll]
        for (uint n = min(THREADING_BLOCK_SIZE, 32u); n > 1u; n = (n + 1u) >> 1u)
        {
            uint s = (n + 1u) >> 1u;
            if (indexDw + s < n && groupIndex + s < THREADING_BLOCK_SIZE)
                g_Scratch[groupIndex] |= g_Scratch[groupIndex + s];

            GroupMemoryBarrierWithGroupSync();
        }

        GroupBallot ballot = (GroupBallot)0;

        // Explicitly mark this loop as "unroll" to avoid warnings about assigning to an array reference
        [unroll]
        for (uint dwordIndex = 0; dwordIndex < _THREADING_GROUP_BALLOT_DWORDS; ++dwordIndex)
        {
            ballot.dwords[dwordIndex] = g_Scratch[dwordIndex * 32];
        }

        return ballot;
    }

    // Number of threads in the group that passed v == true.
    uint Group::CountBits(bool v)
    {
        return Ballot(v).CountBits();
    }
}
#endif