A game about forced loneliness, made by TACStudios
at master 171 lines 9.1 kB view raw
#ifndef THREADING_SM6_IMPL
#define THREADING_SM6_IMPL

// Shader Model 6 implementation of the Threading::Wave and Threading::Group helpers.
//
// * Wave functions forward to the SM6 wave intrinsics.
// * Group functions are emulated through the groupshared buffer g_Scratch. Every one of them
//   except GetThreadCount / GetThreadIndex / IsFirstThread / GetWaveCount executes
//   GroupMemoryBarrierWithGroupSync(), so it must be reached by all threads of the group
//   (call them only from group-uniform control flow).
// * The emulated tree reductions and Ballot halve the active range on every step and therefore
//   assume THREADING_BLOCK_SIZE is a power of two.
namespace Threading
{
    // Currently we only cover scalar types, as at the time of writing this utility library we only
    // needed emulation for those. Support for vector types is not there yet but can be added as
    // needed (and this comment removed).

    // One 32-bit slot per group thread. Values of every supported scalar type are stored bit-cast
    // to uint (asuint) and reinterpreted on load (asuint / asint / asfloat).
    groupshared uint g_Scratch[THREADING_BLOCK_SIZE];

    // Wave index of this thread within the group, as computed by Init().
    uint Wave::GetIndex() { return indexW; }

    // Records the thread's group index and derives its wave index as groupIndex / lane count.
    // NOTE(review): this assumes threads are packed into waves in linear groupIndex order.
    void Wave::Init(uint groupIndex)
    {
        indexG = groupIndex;
        indexW = indexG / GetLaneCount();
    }

    // Note: The HLSL intrinsics should be correctly replaced by console-specific intrinsics by our API library.
    // WavePrefixSum / WavePrefixProduct are exclusive scans; the Inclusive* variants fold in the
    // calling lane's own value.
    #define DEFINE_API_FOR_TYPE(TYPE) \
        bool Wave::AllEqual(TYPE v)               { return WaveActiveAllEqual(v); }    \
        TYPE Wave::Product(TYPE v)                { return WaveActiveProduct(v); }     \
        TYPE Wave::Sum(TYPE v)                    { return WaveActiveSum(v); }         \
        TYPE Wave::Max(TYPE v)                    { return WaveActiveMax(v); }         \
        TYPE Wave::Min(TYPE v)                    { return WaveActiveMin(v); }         \
        TYPE Wave::InclusivePrefixSum(TYPE v)     { return WavePrefixSum(v) + v; }     \
        TYPE Wave::InclusivePrefixProduct(TYPE v) { return WavePrefixProduct(v) * v; } \
        TYPE Wave::PrefixSum(TYPE v)              { return WavePrefixSum(v); }         \
        TYPE Wave::PrefixProduct(TYPE v)          { return WavePrefixProduct(v); }     \
        TYPE Wave::ReadLaneAt(TYPE v, uint i)     { return WaveReadLaneAt(v, i); }     \
        TYPE Wave::ReadLaneFirst(TYPE v)          { return WaveReadLaneFirst(v); }

    // Currently just support scalars.
    DEFINE_API_FOR_TYPE(uint)
    DEFINE_API_FOR_TYPE(int)
    DEFINE_API_FOR_TYPE(float)

    // The following intrinsics need only be declared once.
    uint  Wave::GetLaneCount()           { return WaveGetLaneCount(); }
    uint  Wave::GetLaneIndex()           { return WaveGetLaneIndex(); }
    bool  Wave::IsFirstLane()            { return WaveIsFirstLane(); }
    bool  Wave::AllTrue(bool v)          { return WaveActiveAllTrue(v); }
    bool  Wave::AnyTrue(bool v)          { return WaveActiveAnyTrue(v); }
    uint4 Wave::Ballot(bool v)           { return WaveActiveBallot(v); }
    uint  Wave::CountBits(bool v)        { return WaveActiveCountBits(v); }
    uint  Wave::PrefixCountBits(bool v)  { return WavePrefixCountBits(v); }
    uint  Wave::And(uint v)              { return WaveActiveBitAnd(v); }
    uint  Wave::Or (uint v)              { return WaveActiveBitOr(v); }
    uint  Wave::Xor(uint v)              { return WaveActiveBitXor(v); }

    // ---- Group emulation helpers ------------------------------------------------------------------
    // Each macro below expands to a complete function body (it ends in `return`). It reads the
    // enclosing function's parameter `v` and the thread's `groupIndex`, and works in g_Scratch.
    // The leading barrier keeps this call's store from racing with threads that may still be
    // reading g_Scratch for a previous Group call.

    // Tree reduction with an infix operator OP (+, *, &, |, ^). Each step folds the upper half of
    // the live range into the lower half; afterwards g_Scratch[0] holds the result, which every
    // thread returns. Requires THREADING_BLOCK_SIZE to be a power of two.
    #define EMULATED_GROUP_REDUCE(TYPE, OP) \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(v); \
        GroupMemoryBarrierWithGroupSync(); \
        [unroll] \
        for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u) \
        { \
            if (groupIndex < s) \
                g_Scratch[groupIndex] = asuint(as##TYPE(g_Scratch[groupIndex]) OP as##TYPE(g_Scratch[groupIndex + s])); \
            GroupMemoryBarrierWithGroupSync(); \
        } \
        return as##TYPE(g_Scratch[0]);

    // Same as EMULATED_GROUP_REDUCE but OP is a binary function such as max / min.
    #define EMULATED_GROUP_REDUCE_CMP(TYPE, OP) \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(v); \
        GroupMemoryBarrierWithGroupSync(); \
        [unroll] \
        for (uint s = THREADING_BLOCK_SIZE / 2u; s > 0u; s >>= 1u) \
        { \
            if (groupIndex < s) \
                g_Scratch[groupIndex] = asuint(OP(as##TYPE(g_Scratch[groupIndex]), as##TYPE(g_Scratch[groupIndex + s]))); \
            GroupMemoryBarrierWithGroupSync(); \
        } \
        return as##TYPE(g_Scratch[0]);

    // Hillis-Steele inclusive scan done in place in g_Scratch: at offset s each thread combines
    // its slot with the slot s positions lower. The barrier between the read and the write keeps
    // the in-place update from clobbering values other threads are still reading. Afterwards each
    // thread returns its left neighbour's inclusive value, i.e. the EXCLUSIVE prefix; thread 0
    // returns FILL_VALUE. FILL_VALUE must be the identity of OP (0 for +, 1 for *).
    #define EMULATED_GROUP_PREFIX(TYPE, OP, FILL_VALUE) \
        GroupMemoryBarrierWithGroupSync(); \
        g_Scratch[groupIndex] = asuint(v); \
        GroupMemoryBarrierWithGroupSync(); \
        [unroll] \
        for (uint s = 1u; s < THREADING_BLOCK_SIZE; s <<= 1u) \
        { \
            TYPE nv = FILL_VALUE; \
            if (groupIndex >= s) \
            { \
                nv = as##TYPE(g_Scratch[groupIndex - s]); \
            } \
            nv = as##TYPE(g_Scratch[groupIndex]) OP nv; \
            GroupMemoryBarrierWithGroupSync(); \
            g_Scratch[groupIndex] = asuint(nv); \
            GroupMemoryBarrierWithGroupSync(); \
        } \
        TYPE result = FILL_VALUE; \
        if (groupIndex > 0u) \
            result = as##TYPE(g_Scratch[groupIndex - 1]); \
        return result;

    // Number of waves the group occupies. Rounded up so that a group that is not a multiple of the
    // wave width (e.g. THREADING_BLOCK_SIZE 32 on 64-lane hardware) is not reported as 0 waves.
    // This matches Wave::Init, whose largest wave index is (THREADING_BLOCK_SIZE - 1) / lanes.
    uint Group::GetWaveCount()
    {
        uint laneCount = WaveGetLaneCount();
        return (THREADING_BLOCK_SIZE + laneCount - 1u) / laneCount;
    }

    // Prefix* are exclusive scans (see EMULATED_GROUP_PREFIX); Inclusive* fold in the thread's own value.
    // ReadThreadAt publishes every thread's value to g_Scratch and returns the value of thread i.
    #define DEFINE_API_FOR_TYPE_GROUP(TYPE) \
        bool Group::AllEqual(TYPE v)                  { return AllTrue(ReadThreadFirst(v) == v); } \
        TYPE Group::Product(TYPE v)                   { EMULATED_GROUP_REDUCE(TYPE, *) }           \
        TYPE Group::Sum(TYPE v)                       { EMULATED_GROUP_REDUCE(TYPE, +) }           \
        TYPE Group::Max(TYPE v)                       { EMULATED_GROUP_REDUCE_CMP(TYPE, max) }     \
        TYPE Group::Min(TYPE v)                       { EMULATED_GROUP_REDUCE_CMP(TYPE, min) }     \
        TYPE Group::InclusivePrefixSum(TYPE v)        { return PrefixSum(v) + v; }                 \
        TYPE Group::InclusivePrefixProduct(TYPE v)    { return PrefixProduct(v) * v; }             \
        TYPE Group::PrefixSum(TYPE v)                 { EMULATED_GROUP_PREFIX(TYPE, +, (TYPE)0) }  \
        TYPE Group::PrefixProduct(TYPE v)             { EMULATED_GROUP_PREFIX(TYPE, *, (TYPE)1) }  \
        TYPE Group::ReadThreadAt(TYPE v, uint i)                                                   \
        {                                                                                          \
            GroupMemoryBarrierWithGroupSync();                                                     \
            g_Scratch[groupIndex] = asuint(v);                                                     \
            GroupMemoryBarrierWithGroupSync();                                                     \
            return as##TYPE(g_Scratch[i]);                                                         \
        }                                                                                          \
        TYPE Group::ReadThreadFirst(TYPE v)           { return ReadThreadAt(v, 0u); }              \
        TYPE Group::ReadThreadShuffle(TYPE v, uint i) { return ReadThreadAt(v, i); }

    // Currently just support scalars.
    DEFINE_API_FOR_TYPE_GROUP(uint)
    DEFINE_API_FOR_TYPE_GROUP(int)
    DEFINE_API_FOR_TYPE_GROUP(float)

    // The following emulated functions need only be declared once.
    uint Group::GetThreadCount() { return THREADING_BLOCK_SIZE; }
    uint Group::GetThreadIndex() { return groupIndex; }
    bool Group::IsFirstThread()  { return groupIndex == 0u; }
    // v converts to 0/1, so the group-wide AND is 1 only if every thread passed true,
    // and the group-wide OR is 1 if any thread did.
    bool Group::AllTrue(bool v)  { return And(v) != 0u; }
    bool Group::AnyTrue(bool v)  { return Or (v) != 0u; }
    // Exclusive count (like WavePrefixCountBits): lower-indexed threads that passed true.
    uint Group::PrefixCountBits(bool v) { return PrefixSum((uint)v); }
    uint Group::And(uint v) { EMULATED_GROUP_REDUCE(uint, &) }
    uint Group::Or (uint v) { EMULATED_GROUP_REDUCE(uint, |) }
    uint Group::Xor(uint v) { EMULATED_GROUP_REDUCE(uint, ^) }

    // Group-wide ballot. Thread t contributes bit (t % 32) of dword (t / 32): each 32-thread
    // segment of g_Scratch is OR-reduced into its first slot, which is then copied out.
    // NOTE(review): assumes _THREADING_GROUP_BALLOT_DWORDS (defined elsewhere) equals
    // ceil(THREADING_BLOCK_SIZE / 32), so every g_Scratch read below stays in bounds.
    GroupBallot Group::Ballot(bool v)
    {
        // Bit position inside this thread's dword. The thread's scratch slot is groupIndex itself,
        // since (groupIndex / 32) * 32 + groupIndex % 32 == groupIndex.
        uint indexDw = groupIndex % 32u;

        GroupMemoryBarrierWithGroupSync();

        g_Scratch[groupIndex] = (uint)v << indexDw;

        GroupMemoryBarrierWithGroupSync();

        // OR-reduce within each 32-thread segment (or across the whole group if it is smaller).
        [unroll]
        for (uint s = min(THREADING_BLOCK_SIZE / 2u, 16u); s > 0u; s >>= 1u)
        {
            if (indexDw < s)
                g_Scratch[groupIndex] = g_Scratch[groupIndex] | g_Scratch[groupIndex + s];

            GroupMemoryBarrierWithGroupSync();
        }

        GroupBallot ballot = (GroupBallot)0;

        // Explicitly mark this loop as "unroll" to avoid warnings about assigning to an array reference
        [unroll]
        for (uint dwordIndex = 0; dwordIndex < _THREADING_GROUP_BALLOT_DWORDS; ++dwordIndex)
        {
            ballot.dwords[dwordIndex] = g_Scratch[dwordIndex * 32];
        }

        return ballot;
    }

    // Returns Ballot(v).CountBits(), presumably the number of threads that passed true.
    uint Group::CountBits(bool v)
    {
        return Ballot(v).CountBits();
    }
}
#endif